Compare commits

...

18 Commits

Author SHA1 Message Date
mi804
3da625432e path 2026-04-23 18:09:16 +08:00
mi804
002e3cdb74 docs 2026-04-23 18:02:58 +08:00
mi804
29bf66cdc9 Merge branch 'main' of https://github.com/modelscope/DiffSynth-Studio 2026-04-23 17:39:10 +08:00
mi804
a80fb84220 style 2026-04-23 17:31:34 +08:00
mi804
394db06d86 codes 2026-04-23 16:52:59 +08:00
mi804
1186379139 noncover 2026-04-22 21:36:30 +08:00
mi804
f2e3427566 reference audio input 2026-04-22 19:16:04 +08:00
mi804
c53c813c12 ace-step train 2026-04-22 17:58:10 +08:00
mi804
b0680ef711 low_vram 2026-04-22 12:47:38 +08:00
mi804
f5a3201d42 t2m 2026-04-21 20:12:15 +08:00
mi804
95cfb77881 t2m 2026-04-21 19:42:57 +08:00
Qifan Zhang
5c89a15b9a Reorder optimizer and logger calls in training loop (#1404) 2026-04-21 13:45:09 +08:00
mi804
9d09e0431c acestep t2m 2026-04-21 13:16:15 +08:00
mi804
a604d76339 pipeline_t2m 2026-04-17 17:45:52 +08:00
mi804
36c203da57 model-code 2026-04-17 17:06:26 +08:00
Hong Zhang
079e51c9f3 Support JoyAI-Image-Edit (#1393)
* auto intergrate joyimage model

* joyimage pipeline

* train

* ready

* styling

* joyai-image docs

* update readme

* pr review
2026-04-15 16:57:11 +08:00
Zhongjie Duan
8f18e24597 skip audio loading if no audio in video (#1397) 2026-04-15 13:52:10 +08:00
Zhongjie Duan
45d973e87d update to version 2.0.8 (#1394) 2026-04-14 16:58:17 +08:00
97 changed files with 8057 additions and 59 deletions

200
README.md
View File

@@ -34,6 +34,10 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **April 23, 2026** ACE-Step open-sourced, welcome a new member to the audio model family! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
@@ -598,6 +602,143 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
</details>
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
</details>
<details>
<summary>Examples</summary>
Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
</details>
### Video Synthesis
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
@@ -877,18 +1018,22 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
</details>
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
### Audio Synthesis
#### ACE-Step: [/docs/en/Model_Details/ACE-Step.md](/docs/en/Model_Details/ACE-Step.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
Running the following code will quickly load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
@@ -899,28 +1044,34 @@ vram_config = {
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
device="cuda",
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
</details>
@@ -929,12 +1080,21 @@ image.save("output.jpg")
<summary>Examples</summary>
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
Example code for ACE-Step is available at: [/examples/ace_step/](/examples/ace_step/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
</details>

View File

@@ -34,6 +34,10 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年4月23日** ACE-Step 开源,欢迎加入音频生成模型家族!支持文生音乐推理、低显存推理和 LoRA 训练能力。详情请参考[文档](/docs/zh/Model_Details/ACE-Step.md)和[示例代码](/examples/ace_step/)。
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
@@ -598,6 +602,143 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
</details>
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
```
</details>
<details>
<summary>示例代码</summary>
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
</details>
#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
</details>
<details>
<summary>示例代码</summary>
JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/)
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
</details>
### 视频生成模型
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
@@ -877,18 +1018,22 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
</details>
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
### 音频生成模型
#### ACE-Step: [/docs/zh/Model_Details/ACE-Step.md](/docs/zh/Model_Details/ACE-Step.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
@@ -899,28 +1044,34 @@ vram_config = {
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = ErnieImagePipeline.from_pretrained(
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device='cuda',
device="cuda",
model_configs=[
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="一只黑白相间的中华田园犬",
negative_prompt="",
height=1024,
width=1024,
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=50,
cfg_scale=4.0,
)
image.save("output.jpg")
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
</details>
@@ -929,12 +1080,21 @@ image.save("output.jpg")
<summary>示例代码</summary>
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
ACE-Step 的示例代码位于:[/examples/ace_step/](/examples/ace_step/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
</details>

View File

@@ -900,4 +900,102 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series
joyai_image_series = [
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
"model_hash": "56592ddfd7d0249d3aa527d24161a863",
"model_name": "joyai_image_dit",
"model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT",
},
{
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
"model_hash": "2d11bf14bba8b4e87477c8199a895403",
"model_name": "joyai_image_text_encoder",
"model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
},
]
ace_step_series = [
# === Standard DiT variants (24 layers, hidden_size=2048) ===
# Covers: turbo, turbo-shift1, turbo-shift3, turbo-continuous, base, sft
# All share identical state_dict structure → same hash
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
},
# === XL DiT variants (32 layers, hidden_size=2560) ===
# Covers: xl-base, xl-sft, xl-turbo
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
"extra_kwargs": {
"hidden_size": 2560,
"intermediate_size": 9728,
"num_hidden_layers": 32,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"encoder_hidden_size": 2048,
"layer_types": ["sliding_attention", "full_attention"] * 16,
},
},
# === Conditioner (shared by all DiT variants, same architecture) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === Qwen3-Embedding (text encoder) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
"model_name": "ace_step_text_encoder",
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
},
# === VAE (AutoencoderOobleck CNN) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "51420834e54474986a7f4be0e4d6f687",
"model_name": "ace_step_vae",
"model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
},
# === Tokenizer (VAE latent discretization: tokenizer + detokenizer) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
# === XL Tokenizer (XL models share same tokenizer architecture) ===
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
]
MODEL_CONFIGS = (
qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
)

View File

@@ -279,6 +279,60 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_dit.Transformer3DModel": {
"diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
# ACE-Step module maps
"diffsynth.models.ace_step_dit.AceStepDiTModel": {
"diffsynth.models.ace_step_dit.AceStepDiTLayer": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_conditioner.AceStepConditionEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_text_encoder.AceStepTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_vae.AceStepVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.ace_step_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"vector_quantize_pytorch.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():

View File

@@ -1,8 +1,9 @@
import math
import math, warnings
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image
import torchaudio
from diffsynth.utils.data.audio import read_audio
class DataProcessingPipeline:
@@ -260,15 +261,43 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
def __call__(self, data: str):
reader = self.get_reader(data)
num_frames = self.get_num_frames(reader)
duration = num_frames / self.frame_rate
waveform, sample_rate = torchaudio.load(data)
target_samples = int(duration * sample_rate)
current_samples = waveform.shape[-1]
if current_samples > target_samples:
waveform = waveform[..., :target_samples]
elif current_samples < target_samples:
padding = target_samples - current_samples
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform, sample_rate
try:
reader = self.get_reader(data)
num_frames = self.get_num_frames(reader)
duration = num_frames / self.frame_rate
waveform, sample_rate = torchaudio.load(data)
target_samples = int(duration * sample_rate)
current_samples = waveform.shape[-1]
if current_samples > target_samples:
waveform = waveform[..., :target_samples]
elif current_samples < target_samples:
padding = target_samples - current_samples
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform, sample_rate
except:
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
return None
class LoadPureAudioWithTorchaudio(DataProcessingOperator):
def __init__(self, target_sample_rate=None, target_duration=None):
self.target_sample_rate = target_sample_rate
self.target_duration = target_duration
self.resample = True if target_sample_rate is not None else False
def __call__(self, data: str):
try:
waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate)
if self.target_duration is not None:
target_samples = int(self.target_duration * sample_rate)
current_samples = waveform.shape[-1]
if current_samples > target_samples:
waveform = waveform[..., :target_samples]
elif current_samples < target_samples:
padding = target_samples - current_samples
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform, sample_rate
except Exception as e:
warnings.warn(f"Cannot load audio in '{data}' due to '{e}'. The audio will be `None`.")
return None

View File

@@ -152,7 +152,7 @@ class BasePipeline(torch.nn.Module):
# remove batch dim
if audio_output.ndim == 3:
audio_output = audio_output.squeeze(0)
return audio_output.float()
return audio_output.float().cpu()
def load_models_to_device(self, model_names):
if self.vram_management_enabled:

View File

@@ -4,7 +4,7 @@ from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image", "ACE-Step"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -14,6 +14,7 @@ class FlowMatchScheduler():
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
"ACE-Step": FlowMatchScheduler.set_timesteps_ace_step,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -142,6 +143,26 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0):
"""ACE-Step Flow Matching scheduler.
Timesteps range from 1.0 to 0.0 (not multiplied by 1000).
Shift transformation: t = shift * t / (1 + (shift - 1) * t)
Args:
num_inference_steps: Number of diffusion steps.
denoising_strength: Denoising strength (1.0 = full denoising).
shift: Timestep shift parameter (default 3.0 for turbo).
"""
num_train_timesteps = 1000
sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
@@ -159,6 +180,18 @@ class FlowMatchScheduler():
timesteps[timestep_id] = timestep
return sigmas, timesteps
@staticmethod
def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 4.0 if shift is None else shift
num_train_timesteps = 1000
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
num_train_timesteps = 1000

View File

@@ -33,15 +33,15 @@ def launch_training_task(
for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
scheduler.step()
optimizer.zero_grad()
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)

View File

@@ -0,0 +1,695 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
)
logger = logging.get_logger(__name__)
def create_4d_mask(
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
is_sliding_window: bool = False,
is_causal: bool = True,
) -> torch.Tensor:
indices = torch.arange(seq_len, device=device)
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
if is_causal:
valid_mask = valid_mask & (diff >= 0)
if is_sliding_window and sliding_window is not None:
if is_causal:
valid_mask = valid_mask & (diff <= sliding_window)
else:
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
valid_mask = valid_mask & padding_mask_4d
min_dtype = torch.finfo(dtype).min
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
mask_tensor.masked_fill_(valid_mask, 0.0)
return mask_tensor
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
hidden_cat = torch.cat([hidden1, hidden2], dim=1)
mask_cat = torch.cat([mask1, mask2], dim=1)
B, L, D = hidden_cat.shape
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
lengths = mask_cat.sum(dim=1)
new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
return hidden_left, new_mask
class Lambda(nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
class AceStepAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
is_cross_attention: bool = False,
is_causal: bool = False,
):
super().__init__()
self.layer_idx = layer_idx
self.head_dim = head_dim or hidden_size // num_attention_heads
self.num_key_value_groups = num_attention_heads // num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = attention_dropout
if is_cross_attention:
is_causal = False
self.is_causal = is_causal
self.is_cross_attention = is_cross_attention
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_type = layer_types[layer_idx]
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
if is_cross_attention:
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
curr_past_key_value = past_key_value.cross_attention_cache
if not is_updated:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
past_key_value.is_updated[self.layer_idx] = True
else:
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
else:
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.num_key_value_groups > 1:
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
attn_output = attention_forward(
query_states, key_states, value_states,
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
attn_mask=attention_mask,
)
attn_weights = None
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class AceStepEncoderLayer(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.layer_idx = layer_idx
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=False,
is_causal=False,
)
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
mlp_config = type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
self.mlp = Qwen3MLP(mlp_config)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=False,
past_key_value=None,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class AceStepLyricEncoder(nn.Module):
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
text_hidden_dim: int = 1024,
num_lyric_encoder_hidden_layers: int = 8,
**kwargs,
):
super().__init__()
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
self.text_hidden_dim = text_hidden_dim
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.use_cache = use_cache
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': self.layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=self.layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_lyric_encoder_hidden_layers)
])
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutput:
output_attentions = output_attentions if output_attentions is not None else False
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
assert input_ids is None, "Only `inputs_embeds` is supported for the lyric encoder."
assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
inputs_embeds = self.embed_tokens(inputs_embeds)
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
seq_len = inputs_embeds.shape[1]
dtype = inputs_embeds.dtype
device = inputs_embeds.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for layer_module in self.layers[: self.num_lyric_encoder_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer_module(
hidden_states, position_embeddings,
self_attn_mask_mapping[layer_module.attention_type],
position_ids, output_attentions,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class AceStepTimbreEncoder(nn.Module):
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
timbre_hidden_dim: int = 64,
num_timbre_encoder_hidden_layers: int = 4,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.use_cache = use_cache
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.timbre_hidden_dim = timbre_hidden_dim
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': self.layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=self.layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_timbre_encoder_hidden_layers)
])
def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
N, d = timbre_embs_packed.shape
device = timbre_embs_packed.device
dtype = timbre_embs_packed.dtype
B = int(refer_audio_order_mask.max().item() + 1)
counts = torch.bincount(refer_audio_order_mask, minlength=B)
max_count = counts.max().item()
sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
sorted_batch_ids = refer_audio_order_mask[sorted_indices]
positions = torch.arange(N, device=device)
batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
positions_in_sorted = positions - batch_starts[sorted_batch_ids]
inverse_indices = torch.empty_like(sorted_indices)
inverse_indices[sorted_indices] = torch.arange(N, device=device)
positions_in_batch = positions_in_sorted[inverse_indices]
indices_2d = refer_audio_order_mask * max_count + positions_in_batch
one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
timbre_embs_flat = one_hot.t() @ timbre_embs_packed
timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
mask_flat = (one_hot.sum(dim=0) > 0).long()
new_mask = mask_flat.reshape(B, max_count)
return timbre_embs_unpack, new_mask
@can_return_tuple
def forward(
self,
refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None,
refer_audio_order_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutput:
inputs_embeds = refer_audio_acoustic_hidden_states_packed
inputs_embeds = self.embed_tokens(inputs_embeds)
seq_len = inputs_embeds.shape[1]
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
dtype = inputs_embeds.dtype
device = inputs_embeds.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer_module in self.layers[: self.num_timbre_encoder_hidden_layers]:
layer_outputs = layer_module(
hidden_states, position_embeddings,
self_attn_mask_mapping[layer_module.attention_type],
position_ids,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states[:, 0, :]
# For packed input: reshape [1, T, D] -> [T, D] for unpacking
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
return timbre_embs_unpack, timbre_embs_mask
class AceStepConditionEncoder(nn.Module):
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
text_hidden_dim: int = 1024,
timbre_hidden_dim: int = 64,
num_lyric_encoder_hidden_layers: int = 8,
num_timbre_encoder_hidden_layers: int = 4,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.use_cache = use_cache
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.text_hidden_dim = text_hidden_dim
self.timbre_hidden_dim = timbre_hidden_dim
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
self.lyric_encoder = AceStepLyricEncoder(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
text_hidden_dim=text_hidden_dim,
num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
)
self.timbre_encoder = AceStepTimbreEncoder(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
timbre_hidden_dim=timbre_hidden_dim,
num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
)
def forward(
self,
text_hidden_states: Optional[torch.FloatTensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
lyric_hidden_states: Optional[torch.LongTensor] = None,
lyric_attention_mask: Optional[torch.Tensor] = None,
reference_latents: Optional[torch.Tensor] = None,
refer_audio_order_mask: Optional[torch.LongTensor] = None,
):
text_hidden_states = self.text_projector(text_hidden_states)
lyric_encoder_outputs = self.lyric_encoder(
inputs_embeds=lyric_hidden_states,
attention_mask=lyric_attention_mask,
)
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
encoder_hidden_states, encoder_attention_mask = pack_sequences(
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
)
encoder_hidden_states, encoder_attention_mask = pack_sequences(
encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask
)
return encoder_hidden_states, encoder_attention_mask

View File

@@ -0,0 +1,901 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..core.attention.attention import attention_forward
from ..core import gradient_checkpoint_forward
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
)
logger = logging.get_logger(__name__)
def create_4d_mask(
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None, # [Batch, Seq_Len]
sliding_window: Optional[int] = None,
is_sliding_window: bool = False,
is_causal: bool = True,
) -> torch.Tensor:
"""
General 4D Attention Mask generator compatible with CPU/Mac/SDPA and Eager mode.
Supports use cases:
1. Causal Full: is_causal=True, is_sliding_window=False (standard GPT)
2. Causal Sliding: is_causal=True, is_sliding_window=True (Mistral/Qwen local window)
3. Bidirectional Full: is_causal=False, is_sliding_window=False (BERT/Encoder)
4. Bidirectional Sliding: is_causal=False, is_sliding_window=True (Longformer local)
Returns:
[Batch, 1, Seq_Len, Seq_Len] additive mask (0.0 for keep, -inf for mask)
"""
# ------------------------------------------------------
# 1. Construct basic geometry mask [Seq_Len, Seq_Len]
# ------------------------------------------------------
# Build index matrices
# i (Query): [0, 1, ..., L-1]
# j (Key): [0, 1, ..., L-1]
indices = torch.arange(seq_len, device=device)
# diff = i - j
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
# Initialize all True (all positions visible)
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
# (A) Handle causality (Causal)
if is_causal:
# i >= j => diff >= 0
valid_mask = valid_mask & (diff >= 0)
# (B) Handle sliding window
if is_sliding_window and sliding_window is not None:
if is_causal:
# Causal sliding: only attend to past window steps
# i - j <= window => diff <= window
# (diff >= 0 already handled above)
valid_mask = valid_mask & (diff <= sliding_window)
else:
# Bidirectional sliding: attend past and future window steps
# |i - j| <= window => abs(diff) <= sliding_window
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
# Expand dimensions to [1, 1, Seq_Len, Seq_Len] for broadcasting
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
# ------------------------------------------------------
# 2. Apply padding mask (Key Masking)
# ------------------------------------------------------
if attention_mask is not None:
# attention_mask shape: [Batch, Seq_Len] (1=valid, 0=padding)
# We want to mask out invalid keys (columns)
# Expand shape: [Batch, 1, 1, Seq_Len]
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
# Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
# Result shape: [B, 1, L, L]
valid_mask = valid_mask & padding_mask_4d
# ------------------------------------------------------
# 3. Convert to additive mask
# ------------------------------------------------------
# Get the minimal value for current dtype
min_dtype = torch.finfo(dtype).min
# Create result tensor filled with -inf by default
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
# Set valid positions to 0.0
mask_tensor.masked_fill_(valid_mask, 0.0)
return mask_tensor
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
"""
Pack two sequences by concatenating and sorting them based on mask values.
Args:
hidden1: First hidden states tensor of shape [B, L1, D]
hidden2: Second hidden states tensor of shape [B, L2, D]
mask1: First mask tensor of shape [B, L1]
mask2: Second mask tensor of shape [B, L2]
Returns:
Tuple of (packed_hidden_states, new_mask) where:
- packed_hidden_states: Packed hidden states with valid tokens (mask=1) first, shape [B, L1+L2, D]
- new_mask: New mask tensor indicating valid positions, shape [B, L1+L2]
"""
# Step 1: Concatenate hidden states and masks along sequence dimension
hidden_cat = torch.cat([hidden1, hidden2], dim=1) # [B, L, D]
mask_cat = torch.cat([mask1, mask2], dim=1) # [B, L]
B, L, D = hidden_cat.shape
# Step 2: Sort indices so that mask values of 1 come before 0
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) # [B, L]
# Step 3: Reorder hidden states using sorted indices
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
# Step 4: Create new mask based on valid sequence lengths
lengths = mask_cat.sum(dim=1) # [B]
new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
return hidden_left, new_mask
class TimestepEmbedding(nn.Module):
"""
Timestep embedding module for diffusion models.
Converts timestep values into high-dimensional embeddings using sinusoidal
positional encoding, followed by MLP layers. Used for conditioning diffusion
models on timestep information.
"""
def __init__(
self,
in_channels: int,
time_embed_dim: int,
scale: float = 1,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act1 = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
self.in_channels = in_channels
self.act2 = nn.SiLU()
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
self.scale = scale
def timestep_embedding(self, t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t: A 1-D tensor of N indices, one per batch element. These may be fractional.
dim: The dimension of the output embeddings.
max_period: Controls the minimum frequency of the embeddings.
Returns:
An (N, D) tensor of positional embeddings.
"""
t = t * self.scale
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.in_channels)
temb = self.linear_1(t_freq.to(t.dtype))
temb = self.act1(temb)
temb = self.linear_2(temb)
timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1))
return temb, timestep_proj
class AceStepAttention(nn.Module):
"""
Multi-headed attention module for AceStep model.
Implements the attention mechanism from 'Attention Is All You Need' paper,
with support for both self-attention and cross-attention modes. Uses RMSNorm
for query and key normalization, and supports sliding window attention for
efficient long-sequence processing.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
is_cross_attention: bool = False,
is_causal: bool = False,
):
super().__init__()
self.layer_idx = layer_idx
self.head_dim = head_dim or hidden_size // num_attention_heads
self.num_key_value_groups = num_attention_heads // num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = attention_dropout
if is_cross_attention:
is_causal = False
self.is_causal = is_causal
self.is_cross_attention = is_cross_attention
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_type = layer_types[layer_idx]
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
# Project and normalize query states
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
# Determine if this is cross-attention (requires encoder_hidden_states)
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
# Cross-attention path: attend to encoder hidden states
if is_cross_attention:
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
# After the first generated token, we can reuse all key/value states from cache
curr_past_key_value = past_key_value.cross_attention_cache
# Conditions for calculating key and value states
if not is_updated:
# Compute and cache K/V for the first time
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
# Update cache: save all key/value states to cache for fast auto-regressive generation
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
# Set flag that this layer's cross-attention cache is updated
past_key_value.is_updated[self.layer_idx] = True
else:
# Reuse cached key/value states for subsequent tokens
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
# No cache used, compute K/V directly
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
# Self-attention path: attend to the same sequence
else:
# Project and normalize key/value states for self-attention
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
# Apply rotary position embeddings (RoPE) if provided
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Update cache for auto-regressive generation
if past_key_value is not None:
# Sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# GGA expansion: if num_key_value_heads < num_attention_heads
if self.num_key_value_groups > 1:
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
# Use DiffSynth unified attention
# Tensors are already in (batch, heads, seq, dim) format -> "b n s d"
attn_output = attention_forward(
query_states, key_states, value_states,
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
attn_mask=attention_mask,
)
attn_weights = None # attention_forward doesn't return weights
# Flatten and project output: (B, n_heads, seq, dim) -> (B, seq, n_heads*dim)
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class AceStepEncoderLayer(nn.Module):
"""
Encoder layer for AceStep model.
Consists of self-attention and MLP (feed-forward) sub-layers with residual connections.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
intermediate_size: int = 6144,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: list = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.layer_idx = layer_idx
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=False,
is_causal=False,
)
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP (feed-forward) sub-layer
self.mlp = Qwen3MLP(
config=type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[
torch.FloatTensor,
Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
]:
# Self-attention with residual connection
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
# Encoders don't use cache
use_cache=False,
past_key_value=None,
**kwargs,
)
hidden_states = residual + hidden_states
# MLP with residual connection
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class AceStepDiTLayer(nn.Module):
"""
DiT (Diffusion Transformer) layer for AceStep model.
Implements a transformer layer with three main components:
1. Self-attention with adaptive layer norm (AdaLN)
2. Cross-attention (optional) for conditioning on encoder outputs
3. Feed-forward MLP with adaptive layer norm
Uses scale-shift modulation from timestep embeddings for adaptive normalization.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
intermediate_size: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
use_cross_attention: bool = True,
):
super().__init__()
self.self_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
self.use_cross_attention = use_cross_attention
if self.use_cross_attention:
self.cross_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.cross_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=True,
)
self.mlp_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.mlp = Qwen3MLP(
config=type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# Extract scale-shift parameters for adaptive layer norm from timestep embeddings
# 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.to(temb.device) + temb
).chunk(6, dim=1)
# Step 1: Self-attention with adaptive layer norm (AdaLN)
# Apply adaptive normalization: norm(x) * (1 + scale) + shift
norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output, self_attn_weights = self.self_attn(
hidden_states=norm_hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=False,
past_key_value=None,
**kwargs,
)
# Apply gated residual connection: x = x + attn_output * gate
hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
# Step 2: Cross-attention (if enabled) for conditioning on encoder outputs
if self.use_cross_attention:
norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
attn_output, cross_attn_weights = self.cross_attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
# Standard residual connection for cross-attention
hidden_states = hidden_states + attn_output
# Step 3: Feed-forward (MLP) with adaptive layer norm
# Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift
norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
ff_output = self.mlp(norm_hidden_states)
# Apply gated residual connection: x = x + mlp_output * gate
hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
class Lambda(nn.Module):
"""
Wrapper module for arbitrary lambda functions.
Allows using lambda functions in nn.Sequential by wrapping them in a Module.
Useful for simple transformations like transpose operations.
"""
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
class AceStepDiTModel(nn.Module):
"""
DiT (Diffusion Transformer) model for AceStep.
Main diffusion model that generates audio latents conditioned on text, lyrics,
and timbre. Uses patch-based processing with transformer layers, timestep
conditioning, and cross-attention to encoder outputs.
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
patch_size: int = 2,
in_channels: int = 192,
audio_acoustic_hidden_dim: int = 64,
encoder_hidden_size: Optional[int] = None,
**kwargs,
):
super().__init__()
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.use_cache = use_cache
encoder_hidden_size = encoder_hidden_size or hidden_size
# Rotary position embeddings for transformer layers
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': self.layer_types,
'sliding_window': sliding_window,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
# Stack of DiT transformer layers
self.layers = nn.ModuleList([
AceStepDiTLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
intermediate_size=intermediate_size,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=self.layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_hidden_layers)
])
self.patch_size = patch_size
# Input projection: patch embedding using 1D convolution
self.proj_in = nn.Sequential(
Lambda(lambda x: x.transpose(1, 2)),
nn.Conv1d(
in_channels=in_channels,
out_channels=hidden_size,
kernel_size=patch_size,
stride=patch_size,
padding=0,
),
Lambda(lambda x: x.transpose(1, 2)),
)
# Timestep embeddings for diffusion conditioning
self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
# Project encoder hidden states to model dimension
self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)
# Output normalization and projection
self.norm_out = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.proj_out = nn.Sequential(
Lambda(lambda x: x.transpose(1, 2)),
nn.ConvTranspose1d(
in_channels=hidden_size,
out_channels=audio_acoustic_hidden_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0,
),
Lambda(lambda x: x.transpose(1, 2)),
)
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
timestep_r: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
context_latents: torch.Tensor,
use_cache: Optional[bool] = False,
past_key_values: Optional[EncoderDecoderCache] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
return_hidden_states: int = None,
custom_layers_config: Optional[dict] = None,
enable_early_exit: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
):
use_cache = use_cache if use_cache is not None else self.use_cache
# Disable cache during training or when gradient checkpointing is enabled
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if self.training:
use_cache = False
# Initialize cache if needed (only during inference for auto-regressive generation)
if not self.training and use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
# Compute timestep embeddings for diffusion conditioning
# Two embeddings: one for timestep t, one for timestep difference (t - r)
temb_t, timestep_proj_t = self.time_embed(timestep)
temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
# Combine embeddings
temb = temb_t + temb_r
timestep_proj = timestep_proj_t + timestep_proj_r
# Concatenate context latents (source latents + chunk masks) with hidden states
hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
# Record original sequence length for later restoration after padding
original_seq_len = hidden_states.shape[1]
# Apply padding if sequence length is not divisible by patch_size
# This ensures proper patch extraction
pad_length = 0
if hidden_states.shape[1] % self.patch_size != 0:
pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0)
# Project input to patches and project encoder states
hidden_states = self.proj_in(hidden_states)
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
# Cache positions
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
)
# Position IDs
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
seq_len = hidden_states.shape[1]
encoder_seq_len = encoder_hidden_states.shape[1]
dtype = hidden_states.dtype
device = hidden_states.device
# Initialize Mask variables
full_attn_mask = None
sliding_attn_mask = None
encoder_attn_mask = None
decoder_attn_mask = None
# Target library discards the passed-in attention_mask for 4D mask
# construction (line 1384: attention_mask = None)
attention_mask = None
# 1. Full Attention (Bidirectional, Global)
full_attn_mask = create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=attention_mask,
sliding_window=None,
is_sliding_window=False,
is_causal=False
)
max_len = max(seq_len, encoder_seq_len)
encoder_attn_mask = create_4d_mask(
seq_len=max_len,
dtype=dtype,
device=device,
attention_mask=attention_mask,
sliding_window=None,
is_sliding_window=False,
is_causal=False
)
encoder_attn_mask = encoder_attn_mask[:, :, :seq_len, :encoder_seq_len]
# 2. Sliding Attention (Bidirectional, Local)
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=attention_mask,
sliding_window=self.sliding_window,
is_sliding_window=True,
is_causal=False
)
# Build mask mapping
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
"encoder_attention_mask": encoder_attn_mask,
}
# Create position embeddings to be shared across all decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_cross_attentions = () if output_attentions else None
# Handle early exit for custom layer configurations
max_needed_layer = float('inf')
if custom_layers_config is not None and enable_early_exit:
max_needed_layer = max(custom_layers_config.keys())
output_attentions = True
if all_cross_attentions is None:
all_cross_attentions = ()
# Process through transformer layers
for index_block, layer_module in enumerate(self.layers):
# Early exit optimization
if index_block > max_needed_layer:
break
# Prepare layer arguments
layer_args = (
hidden_states,
position_embeddings,
timestep_proj,
self_attn_mask_mapping[layer_module.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
encoder_hidden_states,
self_attn_mask_mapping["encoder_attention_mask"],
)
layer_kwargs = flash_attn_kwargs
# Use gradient checkpointing if enabled
layer_outputs = gradient_checkpoint_forward(
layer_module,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*layer_args,
**layer_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions and self.layers[index_block].use_cross_attention:
# layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights)
if len(layer_outputs) >= 3:
all_cross_attentions += (layer_outputs[2],)
if return_hidden_states:
return hidden_states
# Extract scale-shift parameters for adaptive output normalization
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
# Apply adaptive layer norm: norm(x) * (1 + scale) + shift
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
# Project output: de-patchify back to original sequence format
hidden_states = self.proj_out(hidden_states)
# Crop back to original sequence length to ensure exact length match (remove padding)
hidden_states = hidden_states[:, :original_seq_len, :]
outputs = (hidden_states, past_key_values)
if output_attentions:
outputs += (all_cross_attentions,)
return outputs

View File

@@ -0,0 +1,53 @@
import torch
class AceStepTextEncoder(torch.nn.Module):
def __init__(
self,
):
super().__init__()
from transformers import Qwen3Config, Qwen3Model
config = Qwen3Config(
attention_bias=False,
attention_dropout=0.0,
bos_token_id=151643,
dtype="bfloat16",
eos_token_id=151643,
head_dim=128,
hidden_act="silu",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=3072,
layer_types=["full_attention"] * 28,
max_position_embeddings=32768,
max_window_layers=28,
model_type="qwen3",
num_attention_heads=16,
num_hidden_layers=28,
num_key_value_heads=8,
pad_token_id=151643,
rms_norm_eps=1e-06,
rope_scaling=None,
rope_theta=1000000,
sliding_window=None,
tie_word_embeddings=True,
use_cache=True,
use_sliding_window=False,
vocab_size=151669,
)
self.model = Qwen3Model(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
return outputs.last_hidden_state

View File

@@ -0,0 +1,732 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio Tokenizer — VAE latent discretization pathway.
Contains:
- AceStepAudioTokenizer: continuous VAE latent → discrete FSQ tokens
- AudioTokenDetokenizer: discrete tokens → continuous VAE-latent-shaped features
Only used in cover song mode (is_covers=True). Bypassed in text-to-music.
"""
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
)
from vector_quantize_pytorch import ResidualFSQ
logger = logging.get_logger(__name__)
def create_4d_mask(
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
is_sliding_window: bool = False,
is_causal: bool = True,
) -> torch.Tensor:
indices = torch.arange(seq_len, device=device)
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
if is_causal:
valid_mask = valid_mask & (diff >= 0)
if is_sliding_window and sliding_window is not None:
if is_causal:
valid_mask = valid_mask & (diff <= sliding_window)
else:
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
valid_mask = valid_mask & padding_mask_4d
min_dtype = torch.finfo(dtype).min
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
mask_tensor.masked_fill_(valid_mask, 0.0)
return mask_tensor
class Lambda(nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
class AceStepAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
is_cross_attention: bool = False,
is_causal: bool = False,
):
super().__init__()
self.layer_idx = layer_idx
self.head_dim = head_dim or hidden_size // num_attention_heads
self.num_key_value_groups = num_attention_heads // num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = attention_dropout
if is_cross_attention:
is_causal = False
self.is_causal = is_causal
self.is_cross_attention = is_cross_attention
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_type = layer_types[layer_idx]
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
if is_cross_attention:
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
curr_past_key_value = past_key_value.cross_attention_cache
if not is_updated:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
past_key_value.is_updated[self.layer_idx] = True
else:
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
else:
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.num_key_value_groups > 1:
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
attn_output = attention_forward(
query_states, key_states, value_states,
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
attn_mask=attention_mask,
)
attn_weights = None
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class AceStepEncoderLayer(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.layer_idx = layer_idx
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=False,
is_causal=False,
)
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
mlp_config = type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
self.mlp = Qwen3MLP(mlp_config)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=False,
past_key_value=None,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class AttentionPooler(nn.Module):
"""Pools every pool_window_size frames into 1 representation via transformer + CLS token."""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
num_attention_pooler_hidden_layers: int = 2,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Default matches target library config (24 alternating entries).
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Slice layer_types to our own layer count
pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': pooler_layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=pooler_layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_attention_pooler_hidden_layers)
])
@can_return_tuple
def forward(
self,
x,
attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> torch.Tensor:
B, T, P, D = x.shape
x = self.embed_tokens(x)
special_tokens = self.special_token.expand(B, T, 1, -1).to(x.device)
x = torch.cat([special_tokens, x], dim=2)
x = rearrange(x, "b t p c -> (b t) p c")
cache_position = torch.arange(0, x.shape[1], device=x.device)
position_ids = cache_position.unsqueeze(0)
hidden_states = x
position_embeddings = self.rotary_emb(hidden_states, position_ids)
seq_len = x.shape[1]
dtype = x.dtype
device = x.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
for layer_module in self.layers:
layer_outputs = layer_module(
hidden_states, position_embeddings,
attention_mask=self_attn_mask_mapping[layer_module.attention_type],
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
cls_output = hidden_states[:, 0, :]
return rearrange(cls_output, "(b t) c -> b t c", b=B)
class AceStepAudioTokenizer(nn.Module):
"""Converts continuous acoustic features (VAE latents) into discrete quantized tokens.
Input: [B, T, 64] (VAE latent dim)
Output: quantized [B, T/5, 2048], indices [B, T/5, 1]
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
audio_acoustic_hidden_dim: int = 64,
pool_window_size: int = 5,
fsq_dim: int = 2048,
fsq_input_levels: list = None,
fsq_input_num_quantizers: int = 1,
num_attention_pooler_hidden_layers: int = 2,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Default matches target library config (24 alternating entries).
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
self.pool_window_size = pool_window_size
self.fsq_dim = fsq_dim
self.fsq_input_levels = fsq_input_levels or [8, 8, 8, 5, 5, 5]
self.fsq_input_num_quantizers = fsq_input_num_quantizers
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
# Slice layer_types for the attention pooler
pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
self.attention_pooler = AttentionPooler(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=pooler_layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
)
self.quantizer = ResidualFSQ(
dim=self.fsq_dim,
levels=self.fsq_input_levels,
num_quantizers=self.fsq_input_num_quantizers,
force_quantization_f32=False, # avoid autocast bug in vector_quantize_pytorch
)
@can_return_tuple
def forward(
self,
hidden_states: Optional[torch.FloatTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.audio_acoustic_proj(hidden_states)
hidden_states = self.attention_pooler(hidden_states)
quantized, indices = self.quantizer(hidden_states)
return quantized, indices
def tokenize(self, x):
"""Convenience: takes [B, T, 64], rearranges to patches, runs forward."""
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size)
return self.forward(x)
class AudioTokenDetokenizer(nn.Module):
"""Converts quantized audio tokens back to continuous acoustic representations.
Input: [B, T/5, hidden_size] (quantized vectors)
Output: [B, T, 64] (VAE-latent-shaped continuous features)
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
pool_window_size: int = 5,
audio_acoustic_hidden_dim: int = 64,
num_attention_pooler_hidden_layers: int = 2,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Default matches target library config (24 alternating entries).
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.pool_window_size = pool_window_size
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Slice layer_types to our own layer count (use num_audio_decoder_hidden_layers)
detok_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': detok_layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=detok_layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_attention_pooler_hidden_layers)
])
self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)
@can_return_tuple
def forward(
self,
x,
attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> torch.Tensor:
B, T, D = x.shape
x = self.embed_tokens(x)
x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
special_tokens = self.special_tokens.expand(B, T, -1, -1)
x = x + special_tokens.to(x.device)
x = rearrange(x, "b t p c -> (b t) p c")
cache_position = torch.arange(0, x.shape[1], device=x.device)
position_ids = cache_position.unsqueeze(0)
hidden_states = x
position_embeddings = self.rotary_emb(hidden_states, position_ids)
seq_len = x.shape[1]
dtype = x.dtype
device = x.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
for layer_module in self.layers:
layer_outputs = layer_module(
hidden_states, position_embeddings,
attention_mask=self_attn_mask_mapping[layer_module.attention_type],
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_out(hidden_states)
return rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.pool_window_size)
class AceStepTokenizer(nn.Module):
"""Container for AceStepAudioTokenizer + AudioTokenDetokenizer.
Provides encode/decode convenience methods for VAE latent discretization.
Used in cover song mode to convert source audio latents to discrete tokens
and back to continuous conditioning hints.
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
audio_acoustic_hidden_dim: int = 64,
pool_window_size: int = 5,
fsq_dim: int = 2048,
fsq_input_levels: list = None,
fsq_input_num_quantizers: int = 1,
num_attention_pooler_hidden_layers: int = 2,
num_audio_decoder_hidden_layers: int = 24,
**kwargs,
):
super().__init__()
# Default layer_types matches target library config (24 alternating entries).
# Sub-modules (pooler/detokenizer) slice first N entries for their own layer count.
if layer_types is None:
layer_types = ["sliding_attention", "full_attention"] * 12
self.tokenizer = AceStepAudioTokenizer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
pool_window_size=pool_window_size,
fsq_dim=fsq_dim,
fsq_input_levels=fsq_input_levels,
fsq_input_num_quantizers=fsq_input_num_quantizers,
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
**kwargs,
)
self.detokenizer = AudioTokenDetokenizer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
pool_window_size=pool_window_size,
audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
**kwargs,
)
def encode(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""VAE latent [B, T, 64] → discrete tokens."""
return self.tokenizer(hidden_states)
def decode(self, quantized: torch.Tensor) -> torch.Tensor:
"""Discrete tokens [B, T/5, hidden_size] → continuous [B, T, 64]."""
return self.detokenizer(quantized)
def tokenize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convenience: [B, T, 64] → quantized + indices via patch rearrangement."""
return self.tokenizer.tokenize(x)

View File

@@ -0,0 +1,287 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture).
This is a CNN-based VAE for audio waveform encoding/decoding.
It uses weight-normalized convolutions and Snake1d activations.
Does NOT depend on diffusers — pure nn.Module implementation.
"""
import math
from typing import Optional
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm
class Snake1d(nn.Module):
"""Snake activation: x + 1/(beta+eps) * sin(alpha*x)^2."""
def __init__(self, hidden_dim: int, logscale: bool = True):
super().__init__()
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.logscale = logscale
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
shape = hidden_states.shape
alpha = torch.exp(self.alpha) if self.logscale else self.alpha
beta = torch.exp(self.beta) if self.logscale else self.beta
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
return hidden_states.reshape(shape)
class OobleckResidualUnit(nn.Module):
"""Residual unit: Snake1d → Conv1d(dilated) → Snake1d → Conv1d(1×1) + skip."""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
self.snake2 = Snake1d(dimension)
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
output = self.conv1(self.snake1(hidden_state))
output = self.conv2(self.snake2(output))
padding = (hidden_state.shape[-1] - output.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
return hidden_state + output
class OobleckEncoderBlock(nn.Module):
"""Encoder block: 3 residual units + downsampling conv."""
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
super().__init__()
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
self.snake1 = Snake1d(input_dim)
self.conv1 = weight_norm(
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
return self.conv1(hidden_state)
class OobleckDecoderBlock(nn.Module):
"""Decoder block: upsampling conv + 3 residual units."""
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
super().__init__()
self.snake1 = Snake1d(input_dim)
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2),
)
)
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
return self.res_unit3(hidden_state)
class OobleckEncoder(nn.Module):
"""Full encoder: audio → latent representation [B, encoder_hidden_size, T'].
conv1 → [blocks] → snake1 → conv2
"""
def __init__(
self,
encoder_hidden_size: int = 128,
audio_channels: int = 2,
downsampling_ratios: list = None,
channel_multiples: list = None,
):
super().__init__()
downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
channel_multiples = [1] + channel_multiples
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
self.block = nn.ModuleList()
for stride_index, stride in enumerate(downsampling_ratios):
self.block.append(
OobleckEncoderBlock(
input_dim=encoder_hidden_size * channel_multiples[stride_index],
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
stride=stride,
)
)
d_model = encoder_hidden_size * channel_multiples[-1]
self.snake1 = Snake1d(d_model)
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv1(hidden_state)
for block in self.block:
hidden_state = block(hidden_state)
hidden_state = self.snake1(hidden_state)
return self.conv2(hidden_state)
class OobleckDecoder(nn.Module):
"""Full decoder: latent → audio waveform [B, audio_channels, T].
conv1 → [blocks] → snake1 → conv2(no bias)
"""
def __init__(
self,
channels: int = 128,
input_channels: int = 64,
audio_channels: int = 2,
upsampling_ratios: list = None,
channel_multiples: list = None,
):
super().__init__()
upsampling_ratios = upsampling_ratios or [10, 6, 4, 4, 2]
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
channel_multiples = [1] + channel_multiples
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
self.block = nn.ModuleList()
for stride_index, stride in enumerate(upsampling_ratios):
self.block.append(
OobleckDecoderBlock(
input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index],
output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1],
stride=stride,
)
)
self.snake1 = Snake1d(channels)
# conv2 has no bias (matches checkpoint: only weight_g/weight_v, no bias key)
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv1(hidden_state)
for block in self.block:
hidden_state = block(hidden_state)
hidden_state = self.snake1(hidden_state)
return self.conv2(hidden_state)
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = torch.randn(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
class AceStepVAE(nn.Module):
"""Audio VAE for ACE-Step (AutoencoderOobleck architecture).
Encodes audio waveform → latent, decodes latent → audio waveform.
Uses Snake1d activations and weight-normalized convolutions.
"""
def __init__(
self,
encoder_hidden_size: int = 128,
downsampling_ratios: list = None,
channel_multiples: list = None,
decoder_channels: int = 128,
decoder_input_channels: int = 64,
audio_channels: int = 2,
sampling_rate: int = 48000,
):
super().__init__()
downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
upsampling_ratios = downsampling_ratios[::-1]
self.encoder = OobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
self.decoder = OobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=upsampling_ratios,
channel_multiples=channel_multiples,
)
self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
h = self.encoder(x)
output = OobleckDiagonalGaussianDistribution(h).sample()
return output
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
return self.decoder(z)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
"""Full round-trip: encode → decode."""
z = self.encode(sample)
return self.decode(z)
def remove_weight_norm(self):
"""Remove weight normalization from all conv layers (for export/inference)."""
for module in self.modules():
if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
remove_weight_norm(module)

View File

@@ -0,0 +1,636 @@
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
) -> torch.Tensor:
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
emb = scale * emb
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
return get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = nn.SiLU()
time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
self.post_act = nn.SiLU() if post_act_fn == "silu" else None
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaTextProjection(nn.Module):
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj(hidden_states)
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
return hidden_states
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
# Build activation + projection matching diffusers pattern
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
elif activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
else:
act_fn = GELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
if len(args) == 0:
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = _to_tuple(args[1], dim=dim)
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij")
grid = torch.stack(grid, dim=0)
return grid
def reshape_for_broadcast(freqs_cis, x, head_first=False):
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
if head_first:
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
if head_first:
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
if isinstance(pos, int):
pos = torch.arange(pos).float()
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
freqs = torch.outer(pos * interpolation_factor, freqs)
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
return freqs_cos, freqs_sin
else:
return torch.polar(torch.ones_like(freqs), freqs)
def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
if isinstance(theta_rescale_factor, (int, float)):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
if isinstance(interpolation_factor, (int, float)):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid[i].reshape(-1), theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs.append(emb)
if use_real:
vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1))
else:
vis_emb = torch.cat(embs, dim=1)
if txt_rope_size is not None:
embs_txt = []
vis_max_ids = grid.view(-1).max().item()
grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i], grid_txt, theta,
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
)
embs_txt.append(emb)
if use_real:
txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1))
else:
txt_emb = torch.cat(embs_txt, dim=1)
else:
txt_emb = None
return vis_emb, txt_emb
class ModulateWan(nn.Module):
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
super().__init__()
self.factor = factor
factory_kwargs = {"dtype": dtype, "device": device}
self.modulate_table = nn.Parameter(
torch.zeros(1, factor, hidden_size, **factory_kwargs) / hidden_size**0.5,
requires_grad=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if len(x.shape) != 3:
x = x.unsqueeze(1)
return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]
def modulate(x, shift=None, scale=None):
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
if modulate_type == 'wanx':
return ModulateWan(hidden_size, factor, **factory_kwargs)
raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
class RMSNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal dit block with separate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: Optional[str] = "wanx",
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dit_modulation_type = dit_modulation_type
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
self.txt_mod = load_modulation(
modulate_type=self.dit_modulation_type,
hidden_size=hidden_size, factor=6, **factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
vis_freqs_cis: tuple = None,
txt_freqs_cis: tuple = None,
attn_kwargs: Optional[dict] = {},
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift, img_mod1_scale, img_mod1_gate,
img_mod2_shift, img_mod2_scale, img_mod2_gate,
) = self.img_mod(vec)
(
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
) = self.txt_mod(vec)
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
if vis_freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
img_q, img_k = img_qq, img_kk
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
if txt_freqs_cis is not None:
raise NotImplementedError("RoPE text is not supported for inference")
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
# Use DiffSynth unified attention
attn_out = attention_forward(
q, k, v,
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
)
attn_out = attn_out.flatten(2, 3)
img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
return img, txt
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
return temb, timestep_proj, encoder_hidden_states
class JoyAIImageDiT(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 16,
out_channels: int = 16,
hidden_size: int = 4096,
heads_num: int = 32,
text_states_dim: int = 4096,
mlp_width_ratio: float = 4.0,
mm_double_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
rope_type: str = 'rope',
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
dit_modulation_type: str = "wanx",
theta: int = 10000,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.patch_size = patch_size
self.hidden_size = hidden_size
self.heads_num = heads_num
self.rope_dim_list = rope_dim_list
self.dit_modulation_type = dit_modulation_type
self.mm_double_blocks_depth = mm_double_blocks_depth
self.rope_type = rope_type
self.theta = theta
factory_kwargs = {"device": device, "dtype": dtype}
if hidden_size % heads_num != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
self.condition_embedder = WanTimeTextImageEmbedding(
dim=hidden_size,
time_freq_dim=256,
time_proj_dim=hidden_size * 6,
text_embed_dim=text_states_dim,
)
self.double_blocks = nn.ModuleList([
MMDoubleStreamBlock(
self.hidden_size, self.heads_num,
mlp_width_ratio=mlp_width_ratio,
dit_modulation_type=self.dit_modulation_type,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
])
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)
def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
target_ndim = 3
if len(vis_rope_size) != target_ndim:
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
head_dim = self.hidden_size // self.heads_num
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim
vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
rope_dim_list, vis_rope_size,
txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
theta=self.theta, use_real=True, theta_rescale_factor=1,
)
return vis_freqs, txt_freqs
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_mask: torch.Tensor = None,
return_dict: bool = True,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
is_multi_item = (len(hidden_states.shape) == 6)
num_items = 0
if is_multi_item:
num_items = hidden_states.shape[1]
if num_items > 1:
assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')
batch_size, _, ot, oh, ow = hidden_states.shape
tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
if encoder_hidden_states_mask is None:
encoder_hidden_states_mask = torch.ones(
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
dtype=torch.bool,
).to(encoder_hidden_states.device)
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
if vec.shape[-1] > self.hidden_size:
vec = vec.unflatten(1, (6, -1))
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
vis_rope_size=(tt, th, tw),
txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
)
for block in self.double_blocks:
img, txt = gradient_checkpoint_forward(
block,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
img=img, txt=txt, vec=vec,
vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
attn_kwargs={},
)
img_len = img.shape[1]
x = torch.cat((img, txt), 1)
img = x[:, :img_len, ...]
img = self.proj_out(self.norm_out(img))
img = self.unpatchify(img, tt, th, tw)
if is_multi_item:
img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
if num_items > 1:
img = torch.cat([img[:, 1:], img[:, :1]], dim=1)
return img
def unpatchify(self, x, t, h, w):
c = self.out_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
x = torch.einsum("nthwopqc->nctohpwq", x)
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

View File

@@ -0,0 +1,82 @@
import torch
from typing import Optional
class JoyAIImageTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
config = Qwen3VLConfig(
text_config={
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"max_position_embeddings": 262144,
"model_type": "qwen3_vl_text",
"num_attention_heads": 32,
"num_hidden_layers": 36,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-6,
"rope_scaling": {
"mrope_interleaved": True,
"mrope_section": [24, 20, 20],
"rope_type": "default",
},
"rope_theta": 5000000,
"use_cache": True,
"vocab_size": 151936,
},
vision_config={
"deepstack_visual_indexes": [8, 16, 24],
"depth": 27,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"in_channels": 3,
"initializer_range": 0.02,
"intermediate_size": 4304,
"model_type": "qwen3_vl",
"num_heads": 16,
"num_position_embeddings": 2304,
"out_hidden_size": 4096,
"patch_size": 16,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
},
image_token_id=151655,
video_token_id=151656,
vision_start_token_id=151652,
vision_end_token_id=151653,
tie_word_embeddings=False,
)
self.model = Qwen3VLForConditionalGeneration(config)
self.config = config
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
**kwargs,
):
pre_norm_output = [None]
def hook_fn(module, args, kwargs_output=None):
pre_norm_output[0] = args[0]
self.model.model.language_model.norm.register_forward_hook(hook_fn)
_ = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
attention_mask=attention_mask,
output_hidden_states=True,
**kwargs,
)
return pre_norm_output[0]

View File

@@ -0,0 +1,584 @@
"""
ACE-Step Pipeline for DiffSynth-Studio.
Text-to-Music generation pipeline using ACE-Step 1.5 model.
"""
import re, torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random, math
import torch.nn.functional as F
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ace_step_dit import AceStepDiTModel
from ..models.ace_step_conditioner import AceStepConditionEncoder
from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE
from ..models.ace_step_tokenizer import AceStepTokenizer
class AceStepPipeline(BasePipeline):
"""Pipeline for ACE-Step text-to-music generation."""
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=1,
width_division_factor=1,
)
self.scheduler = FlowMatchScheduler("ACE-Step")
self.text_encoder: AceStepTextEncoder = None
self.conditioner: AceStepConditionEncoder = None
self.dit: AceStepDiTModel = None
self.vae: AceStepVAE = None
self.tokenizer_model: AceStepTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
AceStepUnit_TaskTypeChecker(),
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
]
self.model_fn = model_fn_ace_step
self.compilable_models = ["dit"]
self.sample_rate = 48000
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: str = get_device_type(),
model_configs: list[ModelConfig] = [],
text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
vram_limit: float = None,
):
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
pipe.vae.remove_weight_norm()
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
if text_tokenizer_config is not None:
text_tokenizer_config.download_if_necessary()
from transformers import AutoTokenizer
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
if silence_latent_config is not None:
silence_latent_config.download_if_necessary()
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
cfg_scale: float = 1.0,
# Lyrics
lyrics: str = "",
# Task type
task_type: Optional[str] = "text2music",
# Reference audio
reference_audios: List[torch.Tensor] = None,
# Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
# Inpainting
repainting_ranges: Optional[List[Tuple[float, float]]] = None,
repainting_strength: float = 1.0,
# Shape
duration: int = 60,
# Audio Meta
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
vocal_language: Optional[str] = "unknown",
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
# Scheduler-specific parameters
shift: float = 1.0,
# Progress
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# Parameters
inputs_posi = {"prompt": prompt, "positive": True}
inputs_nega = {"positive": False}
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio, "audio_cover_strength": audio_cover_strength, "audio_code_string": audio_code_string,
"repainting_ranges": repainting_ranges, "repainting_strength": repainting_strength,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
"rand_device": rand_device,
"num_inference_steps": num_inference_steps,
"shift": shift,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id,
)
inputs_shared["latents"] = self.step(
self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
)
# Decode
self.load_models_to_device(['vae'])
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
latents = inputs_shared["latents"].transpose(1, 2)
vae_output = self.vae.decode(latents)
audio_output = self.normalize_audio(vae_output, target_db=-1.0)
audio = self.output_audio_format_check(audio_output)
self.load_models_to_device([])
return audio
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
peak = torch.max(torch.abs(audio))
if peak < 1e-6:
return audio
target_amp = 10 ** (target_db / 20.0)
gain = target_amp / peak
return audio * gain
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
return
if inputs_shared.get("shared_noncover", None) is None:
return
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
if progress_id >= cover_steps:
inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
input_params=("task_type", "src_audio", "repainting_ranges", "audio_code_string"),
output_params=("task_type",),
)
def process(self, pipe, task_type, src_audio, repainting_ranges, audio_code_string):
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
if task_type == "cover":
assert (src_audio is not None) or (audio_code_string is not None), "For cover task, either src_audio or audio_code_string must be provided."
elif task_type == "repaint":
assert src_audio is not None, "For repaint task, src_audio must be provided."
assert repainting_ranges is not None and len(repainting_ranges) > 0, "For repaint task, inpainting_ranges must be provided and non-empty."
return {}
class AceStepUnit_PromptEmbedder(PipelineUnit):
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
INSTRUCTION_MAP = {
"text2music": "Fill the audio semantic mask based on the given conditions:",
"cover": "Generate audio semantic tokens based on the given conditions:",
"repaint": "Repaint the mask area based on the given conditions:",
"extract": "Extract the {TRACK_NAME} track from the audio:",
"extract_default": "Extract the track from the audio:",
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
"lego_default": "Generate the track based on the audio context:",
"complete": "Complete the input track with {TRACK_CLASSES}:",
"complete_default": "Complete the input track:",
}
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "prompt", "positive": "positive"},
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
onload_model_names=("text_encoder",)
)
def _encode_text(self, pipe, text, max_length=256):
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
text_inputs = pipe.tokenizer(
text,
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder(input_ids, attention_mask)
return hidden_states, attention_mask
def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
text_inputs = pipe.tokenizer(
lyric_text,
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
return hidden_states, attention_mask
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
bpm = meta_dict.get("bpm", "N/A")
timesignature = meta_dict.get("timesignature", "N/A")
keyscale = meta_dict.get("keyscale", "N/A")
duration = meta_dict.get("duration", 30)
duration = f"{int(duration)} seconds"
return (
f"- bpm: {bpm}\n"
f"- timesignature: {timesignature}\n"
f"- keyscale: {keyscale}\n"
f"- duration: {duration}\n"
)
def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
if not positive:
return {}
pipe.load_models_to_device(['text_encoder'])
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
# TODO: remove this
newtext = prompt + "\n\n" + lyric_text
return {
"text_hidden_states": text_hidden_states,
"text_attention_mask": text_attention_mask,
"lyric_hidden_states": lyric_hidden_states,
"lyric_attention_mask": lyric_attention_mask,
}
class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("reference_audios",),
output_params=("reference_latents", "refer_audio_order_mask"),
onload_model_names=("vae",)
)
def process(self, pipe, reference_audios):
if reference_audios is not None:
pipe.load_models_to_device(['vae'])
reference_audios = [
self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
for reference_audio in reference_audios
]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
else:
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
if audio.ndim == 3 and audio.shape[0] == 1:
audio = audio.squeeze(0)
target_frames = 30 * 48000
segment_frames = 10 * 48000
if audio.shape[-1] < target_frames:
repeat_times = math.ceil(target_frames / audio.shape[-1])
audio = audio.repeat(1, repeat_times)
total_frames = audio.shape[-1]
segment_size = total_frames // 3
front_start = random.randint(0, max(0, segment_size - segment_frames))
front_audio = audio[:, front_start:front_start + segment_frames]
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
middle_audio = audio[:, middle_start:middle_start + segment_frames]
back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
back_audio = audio[:, back_start:back_start + segment_frames]
return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask."""
refer_audio_order_mask = []
refer_audio_latents = []
for batch_idx, refer_audios in enumerate(refer_audioss):
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
refer_audio_latent = pipe.silence_latent[:, :750, :]
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
else:
for refer_audio in refer_audios:
refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
return refer_audio_latents, refer_audio_order_mask
class AceStepUnit_ConditionEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
output_params=("encoder_hidden_states", "encoder_attention_mask"),
onload_model_names=("conditioner",),
)
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
pipe.load_models_to_device(['conditioner'])
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
text_hidden_states=inputs_posi.get("text_hidden_states", None),
text_attention_mask=inputs_posi.get("text_attention_mask", None),
lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
reference_latents=inputs_shared.get("reference_latents", None),
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
inputs_shared["vocal_language"], "text2music")
encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
**hidden_states_noncover,
reference_latents=inputs_shared.get("reference_latents", None),
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
if inputs_shared["cfg_scale"] != 1.0:
inputs_shared["nega_noncover"] = {
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
),
"encoder_attention_mask": encoder_attention_mask_noncover,
}
return inputs_shared, inputs_posi, inputs_nega
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("duration", "src_audio", "audio_code_string", "task_type", "repainting_ranges", "repainting_strength"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
onload_model_names=("vae", "tokenizer_model",),
)
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
available = pipe.silence_latent.shape[1]
if length <= available:
return pipe.silence_latent[0, :length, :]
repeats = (length + available - 1) // available
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
if x.shape[1] % pool_window_size != 0:
pad_len = pool_window_size - (x.shape[1] % pool_window_size)
x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
quantized, indices = tokenizer(x)
return quantized
@staticmethod
def _parse_audio_code_string(code_str: str) -> list:
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
if not code_str:
return []
try:
codes = []
max_audio_code = 63999
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
code_value = int(x)
codes.append(max(0, min(code_value, max_audio_code)))
except Exception as e:
raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
def pad_src_audio(self, pipe, src_audio, task_type, repainting_ranges):
if task_type != "repaint" or repainting_ranges is None:
return src_audio, repainting_ranges, None, None
min_left = min([start for start, end in repainting_ranges])
max_right = max([end for start, end in repainting_ranges])
total_length = src_audio.shape[-1] // pipe.vae.sampling_rate
pad_left = max(0, -min_left)
pad_right = max(0, max_right - total_length)
if pad_left > 0 or pad_right > 0:
padding_frames_left, padding_frames_right = pad_left * pipe.vae.sampling_rate, pad_right * pipe.vae.sampling_rate
src_audio = F.pad(src_audio, (padding_frames_left, padding_frames_right), value=0.0)
repainting_ranges = [(start + pad_left, end + pad_left) for start, end in repainting_ranges]
return src_audio, repainting_ranges, pad_left, pad_right
def parse_repaint_masks(self, pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right):
if task_type != "repaint" or repainting_ranges is None:
return None, src_latents
# let repainting area be repainting_strength, non-repainting area be 0.0, and blend at the boundary with cf_frames.
max_latent_length = src_latents.shape[1]
denoise_mask = torch.zeros((1, max_latent_length, 1), dtype=pipe.torch_dtype, device=pipe.device)
for start, end in repainting_ranges:
start_frame = start * pipe.vae.sampling_rate // 1920
end_frame = end * pipe.vae.sampling_rate // 1920
denoise_mask[:, start_frame:end_frame, :] = repainting_strength
# set padding areas to 1.0 (full repaint) to avoid artifacts at the boundaries caused by padding
pad_left_frames = pad_left * pipe.vae.sampling_rate // 1920
pad_right_frames = pad_right * pipe.vae.sampling_rate // 1920
denoise_mask[:, :pad_left_frames, :] = 1
denoise_mask[:, max_latent_length - pad_right_frames:, :] = 1
silent_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
src_latents = src_latents * (1 - denoise_mask) + silent_latents * denoise_mask
return denoise_mask, src_latents
def process(self, pipe, duration, src_audio, audio_code_string, task_type=None, repainting_ranges=None, repainting_strength=None):
# get src_latents from audio_code_string > src_audio > silence
source_latents = None
denoise_mask = None
if audio_code_string is not None:
# use audio_cede_string to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
max_latent_length = src_latents.shape[1]
elif src_audio is not None:
# use src_audio to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
src_audio = torch.clamp(src_audio, -1.0, 1.0)
src_audio, repainting_ranges, pad_left, pad_right = self.pad_src_audio(pipe, src_audio, task_type, repainting_ranges)
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
source_latents = src_latents # cache for potential use in audio inpainting tasks
denoise_mask, src_latents = self.parse_repaint_masks(pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right)
if task_type == "cover":
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
max_latent_length = src_latents.shape[1]
else:
# use silence latents.
max_latent_length = int(duration * pipe.sample_rate // 1920)
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask, "src_latents": source_latents, "denoise_mask": denoise_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_latents", "seed", "rand_device", "src_latents"),
output_params=("noise",),
)
def process(self, pipe, context_latents, seed, rand_device, src_latents):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
if src_latents is not None:
noise = pipe.scheduler.add_noise(src_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
"""Only for training."""
def __init__(self):
super().__init__(
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
if pipe.scheduler.training:
pipe.load_models_to_device(self.onload_model_names)
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
def model_fn_ace_step(
dit: AceStepDiTModel,
latents=None,
timestep=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
context_latents=None,
attention_mask=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
decoder_outputs = dit(
hidden_states=latents,
timestep=timestep,
timestep_r=timestep,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0]
return decoder_outputs

View File

@@ -0,0 +1,282 @@
import torch
from PIL import Image
from typing import Union, Optional
from tqdm import tqdm
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.joyai_image_dit import JoyAIImageDiT
from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
from ..models.wan_video_vae import WanVideoVAE
class JoyAIImagePipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler("Wan")
self.text_encoder: JoyAIImageTextEncoder = None
self.dit: JoyAIImageDiT = None
self.vae: WanVideoVAE = None
self.processor = None
self.in_iteration_models = ("dit",)
self.units = [
JoyAIImageUnit_ShapeChecker(),
JoyAIImageUnit_EditImageEmbedder(),
JoyAIImageUnit_PromptEmbedder(),
JoyAIImageUnit_NoiseInitializer(),
JoyAIImageUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_joyai_image
self.compilable_models = ["dit"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
# Processor
processor_config: ModelConfig = None,
# Optional
vram_limit: float = None,
):
pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder")
pipe.dit = model_pool.fetch_model("joyai_image_dit")
pipe.vae = model_pool.fetch_model("wan_video_vae")
if processor_config is not None:
processor_config.download_if_necessary()
from transformers import AutoProcessor
pipe.processor = AutoProcessor.from_pretrained(processor_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 5.0,
# Image
edit_image: Image.Image = None,
denoising_strength: float = 1.0,
# Shape
height: int = 1024,
width: int = 1024,
# Randomness
seed: int = None,
# Steps
max_sequence_length: int = 4096,
num_inference_steps: int = 30,
# Tiling
tiled: Optional[bool] = False,
tile_size: Optional[tuple[int, int]] = (30, 52),
tile_stride: Optional[tuple[int, int]] = (15, 26),
# Scheduler
shift: Optional[float] = 4.0,
# Progress bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# Parameters
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"edit_image": edit_image,
"denoising_strength": denoising_strength,
"height": height, "width": width,
"seed": seed, "max_sequence_length": max_sequence_length,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
}
# Unit chain
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
# Decode
self.load_models_to_device(['vae'])
latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
image = self.vae.decode(latents, device=self.device)[0]
image = self.vae_output_to_image(image, pattern="C 1 H W")
self.load_models_to_device([])
return image
class JoyAIImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: "JoyAIImagePipeline", height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
prompt_template_encode = {
'image':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
'multiple_images':
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
'video':
"<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
}
prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
input_params=("edit_image", "max_sequence_length"),
output_params=("prompt_embeds", "prompt_embeds_mask"),
onload_model_names=("joyai_image_text_encoder",),
)
def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length):
pipe.load_models_to_device(self.onload_model_names)
has_image = edit_image is not None
if has_image:
prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length)
else:
prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)
return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}
def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length):
template = self.prompt_template_encode['multiple_images']
drop_idx = self.prompt_template_encode_start_idx['multiple_images']
image_tokens = '<image>\n'
prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
prompt = prompt.replace('<image>\n', '<|vision_start|><|image_pad|><|vision_end|>')
prompt = template.format(prompt)
inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device)
last_hidden_states = pipe.text_encoder(**inputs)
prompt_embeds = last_hidden_states[:, drop_idx:]
prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]
if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]
return prompt_embeds, prompt_embeds_mask
def _encode_text_only(self, pipe, prompt, max_sequence_length):
# TODO: may support for text-only encoding in the future.
raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.")
return prompt_embeds, encoder_attention_mask
class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"),
output_params=("ref_latents", "num_items", "is_multi_item"),
onload_model_names=("wan_video_vae",),
)
def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width):
if edit_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
# Resize edit image to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents
edit_image = edit_image.resize((width, height), Image.LANCZOS)
images = [pipe.preprocess_image(edit_image).transpose(0, 1)]
latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype)
return {"ref_latents": ref_vae, "edit_image": edit_image}
class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("seed", "height", "width", "rand_device"),
output_params=("noise"),
)
def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
latent_h = height // pipe.vae.upsampling_factor
latent_w = width // pipe.vae.upsampling_factor
shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
return {"noise": noise}
class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(input_image, Image.Image):
input_image = [input_image]
input_image = [pipe.preprocess_image(img).transpose(0, 1) for img in input_image]
latents = pipe.vae.encode(input_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image)))
return {"latents": noise, "input_latents": input_latents}
def model_fn_joyai_image(
dit,
latents,
timestep,
prompt_embeds,
prompt_embeds_mask,
ref_latents=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents
img = dit(
hidden_states=img,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)
img = img[:, -latents.size(1):]
return img

View File

@@ -99,6 +99,7 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
"""
if waveform.dim() == 3:
waveform = waveform[0]
waveform.cpu()
if backend == "torchcodec":
from torchcodec.encoders import AudioEncoder

View File

@@ -0,0 +1,13 @@
def AceStepConditionEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "encoder."
for key in state_dict:
if key.startswith(prefix):
new_key = key[len(prefix):]
new_state_dict[new_key] = state_dict[key]
if "null_condition_emb" in state_dict:
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
return new_state_dict

View File

@@ -0,0 +1,10 @@
def AceStepDiTModelStateDictConverter(state_dict):
new_state_dict = {}
prefix = "decoder."
for key in state_dict:
if key.startswith(prefix):
new_key = key[len(prefix):]
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,15 @@
def AceStepTextEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "model."
nested_prefix = "model.model."
for key in state_dict:
if key.startswith(nested_prefix):
new_key = key
elif key.startswith(prefix):
new_key = "model." + key
else:
new_key = "model." + key
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,8 @@
def AceStepTokenizerStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key.startswith("tokenizer.") or key.startswith("detokenizer."):
new_state_dict[key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,20 @@
def JoyAIImageTextEncoderStateDictConverter(state_dict):
"""Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.
Mapping (checkpoint -> wrapper):
- lm_head.weight -> model.lm_head.weight
- model.language_model.* -> model.model.language_model.*
- model.visual.* -> model.model.visual.*
"""
state_dict_ = {}
for key in state_dict:
if key == "lm_head.weight":
new_key = "model.lm_head.weight"
elif key.startswith("model.language_model."):
new_key = "model.model." + key[len("model."):]
elif key.startswith("model.visual."):
new_key = "model.model." + key[len("model."):]
else:
new_key = key
state_dict_[new_key] = state_dict[key]
return state_dict_

View File

@@ -0,0 +1,164 @@
# ACE-Step
ACE-Step 1.5 is an open-source music generation model based on DiT architecture, supporting text-to-music, audio cover, repainting and other functionalities, running efficiently on consumer-grade hardware.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
## Model Inference
The model is loaded via `AceStepPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `AceStepPipeline` inference include:
* `prompt`: Text description of the music.
* `cfg_scale`: Classifier-free guidance scale, defaults to 1.0.
* `lyrics`: Lyrics text.
* `task_type`: Task type,可选 values include `"text2music"` (text-to-music), `"cover"` (audio cover), `"repaint"` (repainting), defaults to `"text2music"`.
* `reference_audios`: List of reference audio tensors for timbre reference.
* `src_audio`: Source audio tensor for cover or repaint tasks.
* `denoising_strength`: Denoising strength, controlling how much the output is influenced by source audio, defaults to 1.0.
* `audio_cover_strength`: Audio cover step ratio, controlling how many steps use cover condition in cover tasks, defaults to 1.0.
* `audio_code_string`: Input audio code string for cover tasks with discrete audio codes.
* `repainting_ranges`: List of repainting time ranges (tuples of floats, in seconds) for repaint tasks.
* `repainting_strength`: Repainting intensity, controlling the degree of change in repainted areas, defaults to 1.0.
* `duration`: Audio duration in seconds, defaults to 60.
* `bpm`: Beats per minute, defaults to 100.
* `keyscale`: Musical key scale, defaults to "B minor".
* `timesignature`: Time signature, defaults to "4".
* `vocal_language`: Vocal language, defaults to "unknown".
* `seed`: Random seed.
* `rand_device`: Device for noise generation, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, defaults to 8.
* `shift`: Timestep shift parameter for the scheduler, defaults to 1.0.
## Model Training
Models in the ace_step series are trained uniformly via `examples/ace_step/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* ACE-Step Specific Parameters
* `--tokenizer_path`: Tokenizer path, in format model_id:origin_pattern.
* `--silence_latent_path`: Silence latent path, in format model_id:origin_pattern.
* `--initialize_model_on_cpu`: Whether to initialize models on CPU.
### Example Dataset
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -0,0 +1,154 @@
# JoyAI-Image
JoyAI-Image is a unified multi-modal foundation model open-sourced by JD.com, supporting image understanding, text-to-image generation, and instruction-guided image editing.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 4GB VRAM.
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
## Model Inference
The model is loaded via `JoyAIImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `JoyAIImagePipeline` inference include:
* `prompt`: Text prompt describing the desired image editing effect.
* `negative_prompt`: Negative prompt specifying what should not appear in the result, defaults to empty string.
* `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output more closely follow the prompt.
* `edit_image`: Image to be edited.
* `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0.
* `height`: Height of the output image, defaults to 1024. Must be divisible by 16.
* `width`: Width of the output image, defaults to 1024. Must be divisible by 16.
* `seed`: Random seed for reproducibility. Set to `None` for random seed.
* `max_sequence_length`: Maximum sequence length for the text encoder, defaults to 4096.
* `num_inference_steps`: Number of inference steps, defaults to 30. More steps typically yield better quality.
* `tiled`: Whether to enable tiling for reduced VRAM usage, defaults to False.
* `tile_size`: Tile size, defaults to (30, 52).
* `tile_stride`: Tile stride, defaults to (15, 26).
* `shift`: Shift parameter for the scheduler, controlling the Flow Match scheduling curve, defaults to 4.0.
* `progress_bar_cmd`: Progress bar display mode, defaults to tqdm.
## Model Training
Models in the joyai_image series are trained uniformly via `examples/joyai_image/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* JoyAI-Image Specific Parameters
* `--processor_path`: Path to the processor for processing text and image encoder inputs.
* `--initialize_model_on_cpu`: Whether to initialize models on CPU. By default, models are initialized on the accelerator device.
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -31,6 +31,8 @@ Welcome to DiffSynth-Studio's Documentation
Model_Details/Anima
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
Model_Details/ACE-Step
.. toctree::
:maxdepth: 2

View File

@@ -0,0 +1,164 @@
# ACE-Step
ACE-Step 1.5 是一个开源音乐生成模型,基于 DiT 架构,支持文生音乐、音频翻唱、局部重绘等多种功能,可在消费级硬件上高效运行。
## 安装
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
## 快速开始
运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
## 模型推理
模型通过 `AceStepPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
`AceStepPipeline` 推理的输入参数包括:
* `prompt`: 音乐文本描述。
* `cfg_scale`: 分类器无条件引导比例,默认为 1.0。
* `lyrics`: 歌词文本。
* `task_type`: 任务类型,可选值包括 `"text2music"`(文生音乐)、`"cover"`(音频翻唱)、`"repaint"`(局部重绘),默认为 `"text2music"`
* `reference_audios`: 参考音频列表Tensor 列表),用于提供音色参考。
* `src_audio`: 源音频Tensor用于 cover 或 repaint 任务。
* `denoising_strength`: 降噪强度,控制输出受源音频的影响程度,默认为 1.0。
* `audio_cover_strength`: 音频翻唱步数比例,控制 cover 任务中前多少步使用翻唱条件,默认为 1.0。
* `audio_code_string`: 输入音频码字符串,用于 cover 任务中直接传入离散音频码。
* `repainting_ranges`: 重绘时间区间(浮点元组列表,单位为秒),用于 repaint 任务。
* `repainting_strength`: 重绘强度,控制重绘区域的变化程度,默认为 1.0。
* `duration`: 音频时长(秒),默认为 60。
* `bpm`: 每分钟节拍数,默认为 100。
* `keyscale`: 音阶调式,默认为 "B minor"。
* `timesignature`: 拍号,默认为 "4"。
* `vocal_language`: 演唱语言,默认为 "unknown"。
* `seed`: 随机种子。
* `rand_device`: 噪声生成设备,默认为 "cpu"。
* `num_inference_steps`: 推理步数,默认为 8。
* `shift`: 调度器时间偏移参数,默认为 1.0。
## 模型训练
ace_step 系列模型统一通过 `examples/ace_step/model_training/train.py` 进行训练,脚本的参数包括:
* 通用训练参数
* 数据集基础配置
* `--dataset_base_path`: 数据集的根目录。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
* 模型加载配置
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
* 训练基础配置
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch
* `--trainable_models`: 可训练的模型,例如 `dit``vae``text_encoder`
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* `--weight_decay`: 权重衰减大小。
* `--task`: 训练任务,默认为 `sft`
* 输出配置
* `--output_path`: 模型保存路径。
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
* `--save_steps`: 保存模型的训练步数间隔。
* LoRA 配置
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪些层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`
* 梯度配置
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 分辨率配置
* `--height`: 图像/视频的高度。留空启用动态分辨率。
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
* `--num_frames`: 视频的帧数(仅视频生成模型)。
* ACE-Step 专有参数
* `--tokenizer_path`: Tokenizer 路径,格式为 model_id:origin_pattern。
* `--silence_latent_path`: 静音隐变量路径,格式为 model_id:origin_pattern。
* `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型。
### 样例数据集
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。

View File

@@ -0,0 +1,154 @@
# JoyAI-Image
JoyAI-Image 是京东开源的统一多模态基础模型,支持图像理解、文生图生成和指令引导的图像编辑。
## 安装
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
## 快速开始
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
```python
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")
```
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
## 模型推理
模型通过 `JoyAIImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
`JoyAIImagePipeline` 推理的输入参数包括:
* `prompt`: 文本提示词,用于描述期望的图像编辑效果。
* `negative_prompt`: 负向提示词,指定不希望出现在结果中的内容,默认为空字符串。
* `cfg_scale`: 分类器自由引导的缩放系数,默认为 5.0。值越大,生成结果越贴近 prompt 描述。
* `edit_image`: 待编辑的单张图像。
* `denoising_strength`: 降噪强度,控制输入图像被重绘的程度,默认为 1.0。
* `height`: 输出图像的高度,默认为 1024。需能被 16 整除。
* `width`: 输出图像的宽度,默认为 1024。需能被 16 整除。
* `seed`: 随机种子,用于控制生成的可复现性。设为 `None` 时使用随机种子。
* `max_sequence_length`: 文本编码器处理的最大序列长度,默认为 4096。
* `num_inference_steps`: 推理步数,默认为 30。步数越多生成质量通常越好。
* `tiled`: 是否启用分块处理,用于降低显存占用,默认为 False。
* `tile_size`: 分块大小,默认为 (30, 52)。
* `tile_stride`: 分块步幅,默认为 (15, 26)。
* `shift`: 调度器的 shift 参数,用于控制 Flow Match 的调度曲线,默认为 4.0。
* `progress_bar_cmd`: 进度条显示方式,默认为 tqdm。
## 模型训练
joyai_image 系列模型统一通过 `examples/joyai_image/model_training/train.py` 进行训练,脚本的参数包括:
* 通用训练参数
* 数据集基础配置
* `--dataset_base_path`: 数据集的根目录。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
* 模型加载配置
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
* 训练基础配置
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch
* `--trainable_models`: 可训练的模型,例如 `dit``vae``text_encoder`
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* `--weight_decay`: 权重衰减大小。
* `--task`: 训练任务,默认为 `sft`
* 输出配置
* `--output_path`: 模型保存路径。
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
* `--save_steps`: 保存模型的训练步数间隔。
* LoRA 配置
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪些层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`
* 梯度配置
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 分辨率配置
* `--height`: 图像/视频的高度。留空启用动态分辨率。
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
* `--num_frames`: 视频的帧数(仅视频生成模型)。
* JoyAI-Image 专有参数
* `--processor_path`: Processor 路径,用于处理文本和图像的编码器输入。
* `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型,默认在加速设备上初始化。
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。

View File

@@ -31,6 +31,8 @@
Model_Details/Anima
Model_Details/LTX-2
Model_Details/ERNIE-Image
Model_Details/JoyAI-Image
Model_Details/ACE-Step
.. toctree::
:maxdepth: 2

View File

@@ -0,0 +1,53 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
# input audio codes as reference
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
audio_code_string=audio_code_string,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes.wav")

View File

@@ -0,0 +1,45 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
src_audio=src_audio,
audio_cover_strength=0.5,
denoising_strength=0.9,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")

View File

@@ -0,0 +1,47 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="repaint",
src_audio=src_audio,
repainting_ranges=[(-10, 30), (150, 200)],
repainting_strength=1.0,
duration=210,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")

View File

@@ -0,0 +1,31 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base.wav")

View File

@@ -0,0 +1,38 @@
"""
Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example.
SFT variant is fine-tuned for specific music styles.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft.wav")

View File

@@ -0,0 +1,36 @@
"""
Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
Continuous variant: handles shift range internally, no shift parameter needed.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous.wav")

View File

@@ -0,0 +1,36 @@
"""
Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=1: default value, no need to pass.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1.wav")

View File

@@ -0,0 +1,37 @@
"""
Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
shift=3,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3.wav")

View File

@@ -0,0 +1,38 @@
"""
Ace-Step 1.5 XL Base (32 layers, hidden_size=2560) — Text-to-Music inference example.
XL variant with larger capacity for higher quality generation.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base.wav")

View File

@@ -0,0 +1,37 @@
"""
Ace-Step 1.5 XL SFT (32 layers, supervised fine-tuned) — Text-to-Music inference example.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft.wav")

View File

@@ -0,0 +1,36 @@
"""
Ace-Step 1.5 XL Turbo (32 layers, fast generation) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo.wav")

View File

@@ -0,0 +1,73 @@
"""
Ace-Step 1.5 (main model, turbo) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: uses num_inference_steps=8, cfg_scale=1.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
# input audio codes as reference
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
audio_code_string=audio_code_string,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes-low-vram.wav")

View File

@@ -0,0 +1,57 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
src_audio=src_audio,
audio_cover_strength=0.5,
denoising_strength=0.9,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")

View File

@@ -0,0 +1,59 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="repaint",
src_audio=src_audio,
repainting_ranges=[(-10, 30), (150, 200)],
repainting_strength=1.0,
duration=210,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")

View File

@@ -0,0 +1,49 @@
"""
Ace-Step 1.5 Base — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-low-vram.wav")

View File

@@ -0,0 +1,51 @@
"""
Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
SFT variant is fine-tuned for specific music styles.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft-low-vram.wav")

View File

@@ -0,0 +1,49 @@
"""
Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
Continuous variant: handles shift range internally, no shift parameter needed.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous-low-vram.wav")

View File

@@ -0,0 +1,49 @@
"""
Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=1: default value, no need to pass.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1-low-vram.wav")

View File

@@ -0,0 +1,50 @@
"""
Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
shift=3,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3-low-vram.wav")

View File

@@ -0,0 +1,51 @@
"""
Ace-Step 1.5 XL Base — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
torch.cuda.reset_peak_memory_stats("cuda")
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base-low-vram.wav")

View File

@@ -0,0 +1,50 @@
"""
Ace-Step 1.5 XL SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft-low-vram.wav")

View File

@@ -0,0 +1,48 @@
"""
Ace-Step 1.5 XL Turbo — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo-low-vram.wav")

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
--model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/Ace-Step1.5_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-base_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-sft_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-continuous_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift1_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift3_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-base_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-sft_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-turbo_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
--model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/Ace-Step1.5_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-base_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-sft_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-continuous_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift1_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift3_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-base_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-sft_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-turbo_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,144 @@
import os
import torch
import math
import argparse
import accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class AceStepTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None, silence_latent_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
):
super().__init__()
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
self.pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16, device=device, model_configs=model_configs,
text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config,
)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
}
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"], "positive": True}
inputs_nega = {"positive": False}
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
inputs_shared = {
"input_audio": data["audio"],
"lyrics": data["lyrics"],
"task_type": "text2music",
"duration": duration,
"bpm": data.get("bpm", 100),
"keyscale": data.get("keyscale", "C major"),
"timesignature": data.get("timesignature", "4"),
"vocal_language": data.get("vocal_language", "unknown"),
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def ace_step_parser():
parser = argparse.ArgumentParser(description="ACE-Step training.")
parser = add_general_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Tokenizer path in format model_id:origin_pattern.")
parser.add_argument("--silence_latent_path", type=str, default=None, help="Silence latent path in format model_id:origin_pattern.")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser
if __name__ == "__main__":
parser = ace_step_parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=None,
special_operator_map={
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(
target_sample_rate=48000,
),
},
)
model = AceStepTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
silence_latent_path=args.silence_latent_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/Ace-Step1.5_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-base_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-sft_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-turbo-continuous_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-turbo-shift1_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-turbo-shift3_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-xl-base_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-xl-sft_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-xl-turbo_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=8,
cfg_scale=1.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_full.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/Ace-Step1.5_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-base_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-sft_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-continuous_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift1_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift3_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-base_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-sft_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-turbo_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_lora.wav")

View File

@@ -0,0 +1,39 @@
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=1,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit.png")

View File

@@ -0,0 +1,51 @@
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
import torch
from PIL import Image
from modelscope import dataset_snapshot_download
# Download dataset
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
)
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# Use first sample from dataset
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
prompt = "将裙子改为粉色"
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
output = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
output.save("output_joyai_edit_low_vram.png")

View File

@@ -0,0 +1,35 @@
# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/joyai_image/model_training/train.py \
--dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
--dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
--max_pixels 1048576 \
--dataset_repeat 1 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-full-cache" \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:data_process"
accelerate launch --config_file examples/joyai_image/model_training/full/accelerate_config_zero3.yaml \
examples/joyai_image/model_training/train.py \
--dataset_base_path "./models/train/JoyAI-Image-Edit-full-cache" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-full" \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:train"

View File

@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@@ -0,0 +1,39 @@
# Dataset: data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "joyai_image/JoyAI-Image-Edit/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/joyai_image/model_training/train.py \
--dataset_base_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit" \
--dataset_metadata_path "./data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/metadata.csv" \
--max_pixels 1048576 \
--dataset_repeat 1 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:JoyAI-Image-Und/model*.safetensors,jd-opensource/JoyAI-Image-Edit:vae/Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-split-cache" \
--lora_base_model "dit" \
--lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:data_process"
accelerate launch examples/joyai_image/model_training/train.py \
--dataset_base_path "./models/train/JoyAI-Image-Edit-split-cache" \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "jd-opensource/JoyAI-Image-Edit:transformer/transformer.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/JoyAI-Image-Edit-lora" \
--lora_base_model "dit" \
--lora_target_modules "img_attn_qkv,txt_attn_qkv,img_attn_proj,txt_attn_proj" \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--data_file_keys "image,edit_image" \
--extra_inputs "edit_image" \
--task "sft:train"

View File

@@ -0,0 +1,138 @@
import torch, os, argparse, accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
from diffsynth.diffusion import *
from diffsynth.core.data.operators import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class JoyAIImageTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
processor_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
processor_config = ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/") if processor_path is None else ModelConfig(processor_path)
self.pipe = JoyAIImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, processor_config=processor_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
# Other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
}
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
inputs_shared = {
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
# Please do not modify the following parameters
# unless you clearly know what this will cause.
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def joyai_image_parser():
parser = argparse.ArgumentParser(description="JoyAI-Image training.")
parser = add_general_config(parser)
parser = add_image_size_config(parser)
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor.")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser
if __name__ == "__main__":
parser = joyai_image_parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_image_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=16,
width_division_factor=16,
),
)
model = JoyAIImageTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
processor_path=args.processor_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,32 @@
import torch
from PIL import Image
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
from diffsynth import load_state_dict
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
)
state_dict = load_state_dict("models/train/JoyAI-Image-Edit_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "将裙子改为粉色"
edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB")
image = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=50,
cfg_scale=5.0,
)
image.save("image_full.jpg")

View File

@@ -0,0 +1,30 @@
import torch
from PIL import Image
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
pipe = JoyAIImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors"),
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth"),
],
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
)
pipe.load_lora(pipe.dit, "models/train/JoyAI-Image-Edit-lora/epoch-4.safetensors")
prompt = "将裙子改为粉色"
edit_image = Image.open("data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit/edit/image1.jpg").convert("RGB")
image = pipe(
prompt=prompt,
edit_image=edit_image,
height=1024,
width=1024,
seed=0,
num_inference_steps=30,
cfg_scale=5.0,
)
image.save("image_lora.jpg")

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "diffsynth"
version = "2.0.7"
version = "2.0.9"
description = "Enjoy the magic of Diffusion models!"
authors = [{name = "ModelScope Team"}]
license = {text = "Apache-2.0"}