Compare commits

..

14 Commits
sd ... ace-step

Author SHA1 Message Date
mi804
3da625432e path 2026-04-23 18:09:16 +08:00
mi804
002e3cdb74 docs 2026-04-23 18:02:58 +08:00
mi804
29bf66cdc9 Merge branch 'main' of https://github.com/modelscope/DiffSynth-Studio 2026-04-23 17:39:10 +08:00
mi804
a80fb84220 style 2026-04-23 17:31:34 +08:00
mi804
394db06d86 codes 2026-04-23 16:52:59 +08:00
mi804
1186379139 noncover 2026-04-22 21:36:30 +08:00
mi804
f2e3427566 reference audio input 2026-04-22 19:16:04 +08:00
mi804
c53c813c12 ace-step train 2026-04-22 17:58:10 +08:00
mi804
b0680ef711 low_vram 2026-04-22 12:47:38 +08:00
mi804
f5a3201d42 t2m 2026-04-21 20:12:15 +08:00
mi804
95cfb77881 t2m 2026-04-21 19:42:57 +08:00
mi804
9d09e0431c acestep t2m 2026-04-21 13:16:15 +08:00
mi804
a604d76339 pipeline_t2m 2026-04-17 17:45:52 +08:00
mi804
36c203da57 model-code 2026-04-17 17:06:26 +08:00
109 changed files with 5970 additions and 4703 deletions

205
README.md
View File

@@ -34,7 +34,7 @@ We believe that a well-developed open-source code framework can lower the thresh
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand. > Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
- **April 24, 2026** We add support for Stable Diffusion v1.5 and SDXL, including inference, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/Stable-Diffusion.md), [documentation](/docs/en/Model_Details/Stable-Diffusion-XL.md) and [example code](/examples/stable_diffusion/). - **April 23, 2026** ACE-Step open-sourced, welcome a new member to the audio model family! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/). - **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
@@ -301,129 +301,6 @@ Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image
</details> </details>
#### Stable Diffusion: [/docs/en/Model_Details/Stable-Diffusion.md](/docs/en/Model_Details/Stable-Diffusion.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 2GB VRAM.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Stable Diffusion is available at: [/examples/stable_diffusion/](/examples/stable_diffusion/)
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
</details>
#### Stable Diffusion XL: [/docs/en/Model_Details/Stable-Diffusion-XL.md](/docs/en/Model_Details/Stable-Diffusion-XL.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 6GB VRAM.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>Examples</summary>
Example code for Stable Diffusion XL is available at: [/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
</details>
#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md) #### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
<details> <details>
@@ -1141,6 +1018,86 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
</details> </details>
### Audio Synthesis
#### ACE-Step: [/docs/en/Model_Details/ACE-Step.md](/docs/en/Model_Details/ACE-Step.md)
<details>
<summary>Quick Start</summary>
Running the following code will quickly load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
</details>
<details>
<summary>Examples</summary>
Example code for ACE-Step is available at: [/examples/ace_step/](/examples/ace_step/)
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
</details>
## Innovative Achievements ## Innovative Achievements
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements. DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.

View File

@@ -34,7 +34,7 @@ DiffSynth 目前包括两个开源项目:
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。 > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责因此新功能的开发进展会比较缓慢issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
- **2026年4月24日** 我们新增对 Stable Diffusion v1.5 和 SDXL 的支持,包括推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/Stable-Diffusion.md)和[示例代码](/examples/stable_diffusion/)。 - **2026年4月23日** ACE-Step 开源,欢迎加入音频生成模型家族!支持文生音乐推理、低显存推理和 LoRA 训练能力。详情请参考[文档](/docs/zh/Model_Details/ACE-Step.md)和[示例代码](/examples/ace_step/)。
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。 - **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
@@ -301,129 +301,6 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
</details> </details>
#### Stable Diffusion[/docs/zh/Model_Details/Stable-Diffusion.md](/docs/zh/Model_Details/Stable-Diffusion.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 2GB 显存即可运行。
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
Stable Diffusion 的示例代码位于:[/examples/stable_diffusion/](/examples/stable_diffusion/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
</details>
#### Stable Diffusion XL[/docs/zh/Model_Details/Stable-Diffusion-XL.md](/docs/zh/Model_Details/Stable-Diffusion-XL.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) 模型并进行推理。显存管理已启用,框架会自动根据剩余显存控制模型参数的加载,最低 6GB 显存即可运行。
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")
```
</details>
<details>
<summary>示例代码</summary>
Stable Diffusion XL 的示例代码位于:[/examples/stable_diffusion_xl/](/examples/stable_diffusion_xl/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
</details>
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md) #### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
<details> <details>
@@ -1141,6 +1018,86 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
</details> </details>
### 音频生成模型
#### ACE-Step: [/docs/zh/Model_Details/ACE-Step.md](/docs/zh/Model_Details/ACE-Step.md)
<details>
<summary>快速开始</summary>
运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
</details>
<details>
<summary>示例代码</summary>
ACE-Step 的示例代码位于:[/examples/ace_step/](/examples/ace_step/)
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
</details>
## 创新成果 ## 创新成果
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。 DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。

View File

@@ -900,61 +900,6 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge", "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
}, },
] ]
stable_diffusion_xl_series = [
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
"model_hash": "142b114f67f5ab3a6d83fb5788f12ded",
"model_name": "stable_diffusion_xl_unet",
"model_class": "diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel",
"extra_kwargs": {"attention_head_dim": [5, 10, 20], "transformer_layers_per_block": [1, 2, 10], "use_linear_projection": True, "addition_embed_type": "text_time", "addition_time_embed_dim": 256, "projection_class_embeddings_input_dim": 2816},
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors")
"model_hash": "98cc34ccc5b54ae0e56bdea8688dcd5a",
"model_name": "stable_diffusion_xl_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_xl_text_encoder.SDXLTextEncoder2StateDictConverter",
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
"model_name": "stable_diffusion_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "13115dd45a6e1c39860f91ab073b8a78",
"model_name": "stable_diffusion_xl_vae",
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
"extra_kwargs": {"scaling_factor": 0.13025, "sample_size": 1024, "force_upcast": True},
},
]
stable_diffusion_series = [
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors")
"model_hash": "ffd1737ae9df7fd43f5fbed653bdad67",
"model_name": "stable_diffusion_text_encoder",
"model_class": "diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_text_encoder.SDTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "f86d5683ed32433be8ca69969c67ba69",
"model_name": "stable_diffusion_vae",
"model_class": "diffsynth.models.stable_diffusion_vae.StableDiffusionVAE",
"state_dict_converter": "diffsynth.utils.state_dict_converters.stable_diffusion_vae.SDVAEStateDictConverter",
},
{
# Example: ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors")
"model_hash": "025a4b86a84829399d89f613e580757b",
"model_name": "stable_diffusion_unet",
"model_class": "diffsynth.models.stable_diffusion_unet.UNet2DConditionModel",
},
]
joyai_image_series = [ joyai_image_series = [
{ {
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth") # Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
@@ -971,4 +916,86 @@ joyai_image_series = [
}, },
] ]
MODEL_CONFIGS = stable_diffusion_xl_series + stable_diffusion_series + qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series + z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series ace_step_series = [
# === Standard DiT variants (24 layers, hidden_size=2048) ===
# Covers: turbo, turbo-shift1, turbo-shift3, turbo-continuous, base, sft
# All share identical state_dict structure → same hash
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
},
# === XL DiT variants (32 layers, hidden_size=2560) ===
# Covers: xl-base, xl-sft, xl-turbo
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
"extra_kwargs": {
"hidden_size": 2560,
"intermediate_size": 9728,
"num_hidden_layers": 32,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"encoder_hidden_size": 2048,
"layer_types": ["sliding_attention", "full_attention"] * 16,
},
},
# === Conditioner (shared by all DiT variants, same architecture) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_conditioner",
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
},
# === Qwen3-Embedding (text encoder) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
"model_name": "ace_step_text_encoder",
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
},
# === VAE (AutoencoderOobleck CNN) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "51420834e54474986a7f4be0e4d6f687",
"model_name": "ace_step_vae",
"model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
},
# === Tokenizer (VAE latent discretization: tokenizer + detokenizer) ===
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
# === XL Tokenizer (XL models share same tokenizer architecture) ===
{
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
"model_hash": "3a28a410c2246f125153ef792d8bc828",
"model_name": "ace_step_tokenizer",
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
},
]
MODEL_CONFIGS = (
qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
)

View File

@@ -295,44 +295,43 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
"diffsynth.models.stable_diffusion_unet.UNet2DConditionModel": { # ACE-Step module maps
"diffsynth.models.ace_step_dit.AceStepDiTModel": {
"diffsynth.models.ace_step_dit.AceStepDiTLayer": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
"diffsynth.models.stable_diffusion_vae.StableDiffusionVAE": { "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.stable_diffusion_vae.Upsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.stable_diffusion_vae.Downsample2D": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_text_encoder.SDTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
"diffsynth.models.stable_diffusion_xl_unet.SDXLUNet2DConditionModel": { "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.stable_diffusion_xl_text_encoder.SDXLTextEncoder2": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPTextTransformer": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPEncoderLayer": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.clip.modeling_clip.CLIPAttention": "diffsynth.core.vram.layers.AutoWrappedModule", "transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_vae.AceStepVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.ace_step_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"vector_quantize_pytorch.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
}, },
} }

View File

@@ -3,6 +3,7 @@ import torch, torchvision, imageio, os
import imageio.v3 as iio import imageio.v3 as iio
from PIL import Image from PIL import Image
import torchaudio import torchaudio
from diffsynth.utils.data.audio import read_audio
class DataProcessingPipeline: class DataProcessingPipeline:
@@ -276,3 +277,27 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
except: except:
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.") warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
return None return None
class LoadPureAudioWithTorchaudio(DataProcessingOperator):
def __init__(self, target_sample_rate=None, target_duration=None):
self.target_sample_rate = target_sample_rate
self.target_duration = target_duration
self.resample = True if target_sample_rate is not None else False
def __call__(self, data: str):
try:
waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate)
if self.target_duration is not None:
target_samples = int(self.target_duration * sample_rate)
current_samples = waveform.shape[-1]
if current_samples > target_samples:
waveform = waveform[..., :target_samples]
elif current_samples < target_samples:
padding = target_samples - current_samples
waveform = torch.nn.functional.pad(waveform, (0, padding))
return waveform, sample_rate
except Exception as e:
warnings.warn(f"Cannot load audio in '{data}' due to '{e}'. The audio will be `None`.")
return None

View File

@@ -152,7 +152,7 @@ class BasePipeline(torch.nn.Module):
# remove batch dim # remove batch dim
if audio_output.ndim == 3: if audio_output.ndim == 3:
audio_output = audio_output.squeeze(0) audio_output = audio_output.squeeze(0)
return audio_output.float() return audio_output.float().cpu()
def load_models_to_device(self, model_names): def load_models_to_device(self, model_names):
if self.vram_management_enabled: if self.vram_management_enabled:

View File

@@ -1,107 +0,0 @@
import torch, math
class DDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
self.num_train_timesteps = num_train_timesteps
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
elif beta_schedule == "linear":
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
if rescale_zero_terminal_snr:
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
self.alphas_cumprod = self.alphas_cumprod.tolist()
self.set_timesteps(10)
self.prediction_type = prediction_type
self.training = False
def rescale_zero_terminal_snr(self, alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
return alphas_bar
def set_timesteps(self, num_inference_steps, denoising_strength=1.0, training=False, **kwargs):
# The timesteps are aligned to 999...0, which is different from other implementations,
# but I think this implementation is more reasonable in theory.
max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
num_inference_steps = min(num_inference_steps, max_timestep + 1)
if num_inference_steps == 1:
self.timesteps = torch.Tensor([max_timestep])
else:
step_length = max_timestep / (num_inference_steps - 1)
self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])
self.training = training
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
if self.prediction_type == "epsilon":
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
prev_sample = sample * weight_x + model_output * weight_e
elif self.prediction_type == "v_prediction":
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
prev_sample = sample * weight_x + model_output * weight_e
else:
raise NotImplementedError(f"{self.prediction_type} is not implemented")
return prev_sample
def step(self, model_output, timestep, sample, to_final=False):
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
if to_final or timestep_id + 1 >= len(self.timesteps):
alpha_prod_t_prev = 1.0
else:
timestep_prev = int(self.timesteps[timestep_id + 1])
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)
def return_to_timestep(self, timestep, sample, sample_stablized):
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
return noise_pred
def add_noise(self, original_samples, noise, timestep):
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def training_target(self, sample, noise, timestep):
if self.prediction_type == "epsilon":
return noise
else:
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target
def training_weight(self, timestep):
return 1.0

View File

@@ -4,7 +4,7 @@ from typing_extensions import Literal
class FlowMatchScheduler(): class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"): def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image", "ACE-Step"] = "FLUX.1"):
self.set_timesteps_fn = { self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux, "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan, "Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -14,6 +14,7 @@ class FlowMatchScheduler():
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2, "LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning, "Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image, "ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
"ACE-Step": FlowMatchScheduler.set_timesteps_ace_step,
}.get(template, FlowMatchScheduler.set_timesteps_flux) }.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000 self.num_train_timesteps = 1000
@@ -142,6 +143,26 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps timesteps = sigmas * num_train_timesteps
return sigmas, timesteps return sigmas, timesteps
@staticmethod
def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0):
"""ACE-Step Flow Matching scheduler.
Timesteps range from 1.0 to 0.0 (not multiplied by 1000).
Shift transformation: t = shift * t / (1 + (shift - 1) * t)
Args:
num_inference_steps: Number of diffusion steps.
denoising_strength: Denoising strength (1.0 = full denoising).
shift: Timestep shift parameter (default 3.0 for turbo).
"""
num_train_timesteps = 1000
sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod @staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None): def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0 sigma_min = 0.0

View File

@@ -0,0 +1,695 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
)
logger = logging.get_logger(__name__)
def create_4d_mask(
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
is_sliding_window: bool = False,
is_causal: bool = True,
) -> torch.Tensor:
indices = torch.arange(seq_len, device=device)
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
if is_causal:
valid_mask = valid_mask & (diff >= 0)
if is_sliding_window and sliding_window is not None:
if is_causal:
valid_mask = valid_mask & (diff <= sliding_window)
else:
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
valid_mask = valid_mask & padding_mask_4d
min_dtype = torch.finfo(dtype).min
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
mask_tensor.masked_fill_(valid_mask, 0.0)
return mask_tensor
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
hidden_cat = torch.cat([hidden1, hidden2], dim=1)
mask_cat = torch.cat([mask1, mask2], dim=1)
B, L, D = hidden_cat.shape
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
lengths = mask_cat.sum(dim=1)
new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
return hidden_left, new_mask
class Lambda(nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
class AceStepAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
is_cross_attention: bool = False,
is_causal: bool = False,
):
super().__init__()
self.layer_idx = layer_idx
self.head_dim = head_dim or hidden_size // num_attention_heads
self.num_key_value_groups = num_attention_heads // num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = attention_dropout
if is_cross_attention:
is_causal = False
self.is_causal = is_causal
self.is_cross_attention = is_cross_attention
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_type = layer_types[layer_idx]
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
if is_cross_attention:
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
curr_past_key_value = past_key_value.cross_attention_cache
if not is_updated:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
past_key_value.is_updated[self.layer_idx] = True
else:
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
else:
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.num_key_value_groups > 1:
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
attn_output = attention_forward(
query_states, key_states, value_states,
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
attn_mask=attention_mask,
)
attn_weights = None
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class AceStepEncoderLayer(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.layer_idx = layer_idx
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=False,
is_causal=False,
)
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
mlp_config = type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
self.mlp = Qwen3MLP(mlp_config)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=False,
past_key_value=None,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class AceStepLyricEncoder(nn.Module):
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
text_hidden_dim: int = 1024,
num_lyric_encoder_hidden_layers: int = 8,
**kwargs,
):
super().__init__()
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
self.text_hidden_dim = text_hidden_dim
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.use_cache = use_cache
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': self.layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=self.layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_lyric_encoder_hidden_layers)
])
@can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutput:
output_attentions = output_attentions if output_attentions is not None else False
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
assert input_ids is None, "Only `inputs_embeds` is supported for the lyric encoder."
assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
inputs_embeds = self.embed_tokens(inputs_embeds)
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
seq_len = inputs_embeds.shape[1]
dtype = inputs_embeds.dtype
device = inputs_embeds.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for layer_module in self.layers[: self.num_lyric_encoder_hidden_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer_module(
hidden_states, position_embeddings,
self_attn_mask_mapping[layer_module.attention_type],
position_ids, output_attentions,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class AceStepTimbreEncoder(nn.Module):
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
timbre_hidden_dim: int = 64,
num_timbre_encoder_hidden_layers: int = 4,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.use_cache = use_cache
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.timbre_hidden_dim = timbre_hidden_dim
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': self.layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=self.layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_timbre_encoder_hidden_layers)
])
def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
N, d = timbre_embs_packed.shape
device = timbre_embs_packed.device
dtype = timbre_embs_packed.dtype
B = int(refer_audio_order_mask.max().item() + 1)
counts = torch.bincount(refer_audio_order_mask, minlength=B)
max_count = counts.max().item()
sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
sorted_batch_ids = refer_audio_order_mask[sorted_indices]
positions = torch.arange(N, device=device)
batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
positions_in_sorted = positions - batch_starts[sorted_batch_ids]
inverse_indices = torch.empty_like(sorted_indices)
inverse_indices[sorted_indices] = torch.arange(N, device=device)
positions_in_batch = positions_in_sorted[inverse_indices]
indices_2d = refer_audio_order_mask * max_count + positions_in_batch
one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
timbre_embs_flat = one_hot.t() @ timbre_embs_packed
timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
mask_flat = (one_hot.sum(dim=0) > 0).long()
new_mask = mask_flat.reshape(B, max_count)
return timbre_embs_unpack, new_mask
@can_return_tuple
def forward(
self,
refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None,
refer_audio_order_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutput:
inputs_embeds = refer_audio_acoustic_hidden_states_packed
inputs_embeds = self.embed_tokens(inputs_embeds)
seq_len = inputs_embeds.shape[1]
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
position_ids = cache_position.unsqueeze(0)
dtype = inputs_embeds.dtype
device = inputs_embeds.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer_module in self.layers[: self.num_timbre_encoder_hidden_layers]:
layer_outputs = layer_module(
hidden_states, position_embeddings,
self_attn_mask_mapping[layer_module.attention_type],
position_ids,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states[:, 0, :]
# For packed input: reshape [1, T, D] -> [T, D] for unpacking
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
return timbre_embs_unpack, timbre_embs_mask
class AceStepConditionEncoder(nn.Module):
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
text_hidden_dim: int = 1024,
timbre_hidden_dim: int = 64,
num_lyric_encoder_hidden_layers: int = 8,
num_timbre_encoder_hidden_layers: int = 4,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.use_cache = use_cache
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.text_hidden_dim = text_hidden_dim
self.timbre_hidden_dim = timbre_hidden_dim
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
self.lyric_encoder = AceStepLyricEncoder(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
text_hidden_dim=text_hidden_dim,
num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
)
self.timbre_encoder = AceStepTimbreEncoder(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
timbre_hidden_dim=timbre_hidden_dim,
num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
)
def forward(
self,
text_hidden_states: Optional[torch.FloatTensor] = None,
text_attention_mask: Optional[torch.Tensor] = None,
lyric_hidden_states: Optional[torch.LongTensor] = None,
lyric_attention_mask: Optional[torch.Tensor] = None,
reference_latents: Optional[torch.Tensor] = None,
refer_audio_order_mask: Optional[torch.LongTensor] = None,
):
text_hidden_states = self.text_projector(text_hidden_states)
lyric_encoder_outputs = self.lyric_encoder(
inputs_embeds=lyric_hidden_states,
attention_mask=lyric_attention_mask,
)
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
encoder_hidden_states, encoder_attention_mask = pack_sequences(
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
)
encoder_hidden_states, encoder_attention_mask = pack_sequences(
encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask
)
return encoder_hidden_states, encoder_attention_mask

View File

@@ -0,0 +1,901 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..core.attention.attention import attention_forward
from ..core import gradient_checkpoint_forward
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import logging
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
)
logger = logging.get_logger(__name__)
def create_4d_mask(
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None, # [Batch, Seq_Len]
sliding_window: Optional[int] = None,
is_sliding_window: bool = False,
is_causal: bool = True,
) -> torch.Tensor:
"""
General 4D Attention Mask generator compatible with CPU/Mac/SDPA and Eager mode.
Supports use cases:
1. Causal Full: is_causal=True, is_sliding_window=False (standard GPT)
2. Causal Sliding: is_causal=True, is_sliding_window=True (Mistral/Qwen local window)
3. Bidirectional Full: is_causal=False, is_sliding_window=False (BERT/Encoder)
4. Bidirectional Sliding: is_causal=False, is_sliding_window=True (Longformer local)
Returns:
[Batch, 1, Seq_Len, Seq_Len] additive mask (0.0 for keep, -inf for mask)
"""
# ------------------------------------------------------
# 1. Construct basic geometry mask [Seq_Len, Seq_Len]
# ------------------------------------------------------
# Build index matrices
# i (Query): [0, 1, ..., L-1]
# j (Key): [0, 1, ..., L-1]
indices = torch.arange(seq_len, device=device)
# diff = i - j
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
# Initialize all True (all positions visible)
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
# (A) Handle causality (Causal)
if is_causal:
# i >= j => diff >= 0
valid_mask = valid_mask & (diff >= 0)
# (B) Handle sliding window
if is_sliding_window and sliding_window is not None:
if is_causal:
# Causal sliding: only attend to past window steps
# i - j <= window => diff <= window
# (diff >= 0 already handled above)
valid_mask = valid_mask & (diff <= sliding_window)
else:
# Bidirectional sliding: attend past and future window steps
# |i - j| <= window => abs(diff) <= sliding_window
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
# Expand dimensions to [1, 1, Seq_Len, Seq_Len] for broadcasting
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
# ------------------------------------------------------
# 2. Apply padding mask (Key Masking)
# ------------------------------------------------------
if attention_mask is not None:
# attention_mask shape: [Batch, Seq_Len] (1=valid, 0=padding)
# We want to mask out invalid keys (columns)
# Expand shape: [Batch, 1, 1, Seq_Len]
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
# Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
# Result shape: [B, 1, L, L]
valid_mask = valid_mask & padding_mask_4d
# ------------------------------------------------------
# 3. Convert to additive mask
# ------------------------------------------------------
# Get the minimal value for current dtype
min_dtype = torch.finfo(dtype).min
# Create result tensor filled with -inf by default
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
# Set valid positions to 0.0
mask_tensor.masked_fill_(valid_mask, 0.0)
return mask_tensor
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
"""
Pack two sequences by concatenating and sorting them based on mask values.
Args:
hidden1: First hidden states tensor of shape [B, L1, D]
hidden2: Second hidden states tensor of shape [B, L2, D]
mask1: First mask tensor of shape [B, L1]
mask2: Second mask tensor of shape [B, L2]
Returns:
Tuple of (packed_hidden_states, new_mask) where:
- packed_hidden_states: Packed hidden states with valid tokens (mask=1) first, shape [B, L1+L2, D]
- new_mask: New mask tensor indicating valid positions, shape [B, L1+L2]
"""
# Step 1: Concatenate hidden states and masks along sequence dimension
hidden_cat = torch.cat([hidden1, hidden2], dim=1) # [B, L, D]
mask_cat = torch.cat([mask1, mask2], dim=1) # [B, L]
B, L, D = hidden_cat.shape
# Step 2: Sort indices so that mask values of 1 come before 0
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) # [B, L]
# Step 3: Reorder hidden states using sorted indices
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
# Step 4: Create new mask based on valid sequence lengths
lengths = mask_cat.sum(dim=1) # [B]
new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
return hidden_left, new_mask
class TimestepEmbedding(nn.Module):
"""
Timestep embedding module for diffusion models.
Converts timestep values into high-dimensional embeddings using sinusoidal
positional encoding, followed by MLP layers. Used for conditioning diffusion
models on timestep information.
"""
def __init__(
self,
in_channels: int,
time_embed_dim: int,
scale: float = 1,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act1 = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
self.in_channels = in_channels
self.act2 = nn.SiLU()
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
self.scale = scale
def timestep_embedding(self, t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t: A 1-D tensor of N indices, one per batch element. These may be fractional.
dim: The dimension of the output embeddings.
max_period: Controls the minimum frequency of the embeddings.
Returns:
An (N, D) tensor of positional embeddings.
"""
t = t * self.scale
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.in_channels)
temb = self.linear_1(t_freq.to(t.dtype))
temb = self.act1(temb)
temb = self.linear_2(temb)
timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1))
return temb, timestep_proj
class AceStepAttention(nn.Module):
"""
Multi-headed attention module for AceStep model.
Implements the attention mechanism from 'Attention Is All You Need' paper,
with support for both self-attention and cross-attention modes. Uses RMSNorm
for query and key normalization, and supports sliding window attention for
efficient long-sequence processing.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
is_cross_attention: bool = False,
is_causal: bool = False,
):
super().__init__()
self.layer_idx = layer_idx
self.head_dim = head_dim or hidden_size // num_attention_heads
self.num_key_value_groups = num_attention_heads // num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = attention_dropout
if is_cross_attention:
is_causal = False
self.is_causal = is_causal
self.is_cross_attention = is_cross_attention
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_type = layer_types[layer_idx]
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
# Project and normalize query states
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
# Determine if this is cross-attention (requires encoder_hidden_states)
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
# Cross-attention path: attend to encoder hidden states
if is_cross_attention:
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
# After the first generated token, we can reuse all key/value states from cache
curr_past_key_value = past_key_value.cross_attention_cache
# Conditions for calculating key and value states
if not is_updated:
# Compute and cache K/V for the first time
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
# Update cache: save all key/value states to cache for fast auto-regressive generation
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
# Set flag that this layer's cross-attention cache is updated
past_key_value.is_updated[self.layer_idx] = True
else:
# Reuse cached key/value states for subsequent tokens
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
# No cache used, compute K/V directly
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
# Self-attention path: attend to the same sequence
else:
# Project and normalize key/value states for self-attention
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
# Apply rotary position embeddings (RoPE) if provided
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Update cache for auto-regressive generation
if past_key_value is not None:
# Sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# GGA expansion: if num_key_value_heads < num_attention_heads
if self.num_key_value_groups > 1:
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
# Use DiffSynth unified attention
# Tensors are already in (batch, heads, seq, dim) format -> "b n s d"
attn_output = attention_forward(
query_states, key_states, value_states,
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
attn_mask=attention_mask,
)
attn_weights = None # attention_forward doesn't return weights
# Flatten and project output: (B, n_heads, seq, dim) -> (B, seq, n_heads*dim)
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class AceStepEncoderLayer(nn.Module):
"""
Encoder layer for AceStep model.
Consists of self-attention and MLP (feed-forward) sub-layers with residual connections.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
intermediate_size: int = 6144,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: list = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.layer_idx = layer_idx
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=False,
is_causal=False,
)
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP (feed-forward) sub-layer
self.mlp = Qwen3MLP(
config=type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[
torch.FloatTensor,
Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
]:
# Self-attention with residual connection
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
# Encoders don't use cache
use_cache=False,
past_key_value=None,
**kwargs,
)
hidden_states = residual + hidden_states
# MLP with residual connection
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class AceStepDiTLayer(nn.Module):
"""
DiT (Diffusion Transformer) layer for AceStep model.
Implements a transformer layer with three main components:
1. Self-attention with adaptive layer norm (AdaLN)
2. Cross-attention (optional) for conditioning on encoder outputs
3. Feed-forward MLP with adaptive layer norm
Uses scale-shift modulation from timestep embeddings for adaptive normalization.
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
intermediate_size: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
use_cross_attention: bool = True,
):
super().__init__()
self.self_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
self.use_cross_attention = use_cross_attention
if self.use_cross_attention:
self.cross_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.cross_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=True,
)
self.mlp_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.mlp = Qwen3MLP(
config=type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# Extract scale-shift parameters for adaptive layer norm from timestep embeddings
# 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.to(temb.device) + temb
).chunk(6, dim=1)
# Step 1: Self-attention with adaptive layer norm (AdaLN)
# Apply adaptive normalization: norm(x) * (1 + scale) + shift
norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output, self_attn_weights = self.self_attn(
hidden_states=norm_hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=False,
past_key_value=None,
**kwargs,
)
# Apply gated residual connection: x = x + attn_output * gate
hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
# Step 2: Cross-attention (if enabled) for conditioning on encoder outputs
if self.use_cross_attention:
norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
attn_output, cross_attn_weights = self.cross_attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
**kwargs,
)
# Standard residual connection for cross-attention
hidden_states = hidden_states + attn_output
# Step 3: Feed-forward (MLP) with adaptive layer norm
# Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift
norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
ff_output = self.mlp(norm_hidden_states)
# Apply gated residual connection: x = x + mlp_output * gate
hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
class Lambda(nn.Module):
"""
Wrapper module for arbitrary lambda functions.
Allows using lambda functions in nn.Sequential by wrapping them in a Module.
Useful for simple transformations like transpose operations.
"""
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
class AceStepDiTModel(nn.Module):
"""
DiT (Diffusion Transformer) model for AceStep.
Main diffusion model that generates audio latents conditioned on text, lyrics,
and timbre. Uses patch-based processing with transformer layers, timestep
conditioning, and cross-attention to encoder outputs.
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_hidden_layers: int = 24,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
use_cache: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
patch_size: int = 2,
in_channels: int = 192,
audio_acoustic_hidden_dim: int = 64,
encoder_hidden_size: Optional[int] = None,
**kwargs,
):
super().__init__()
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.use_cache = use_cache
encoder_hidden_size = encoder_hidden_size or hidden_size
# Rotary position embeddings for transformer layers
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': self.layer_types,
'sliding_window': sliding_window,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
# Stack of DiT transformer layers
self.layers = nn.ModuleList([
AceStepDiTLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
intermediate_size=intermediate_size,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=self.layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_hidden_layers)
])
self.patch_size = patch_size
# Input projection: patch embedding using 1D convolution
self.proj_in = nn.Sequential(
Lambda(lambda x: x.transpose(1, 2)),
nn.Conv1d(
in_channels=in_channels,
out_channels=hidden_size,
kernel_size=patch_size,
stride=patch_size,
padding=0,
),
Lambda(lambda x: x.transpose(1, 2)),
)
# Timestep embeddings for diffusion conditioning
self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
# Project encoder hidden states to model dimension
self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)
# Output normalization and projection
self.norm_out = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.proj_out = nn.Sequential(
Lambda(lambda x: x.transpose(1, 2)),
nn.ConvTranspose1d(
in_channels=hidden_size,
out_channels=audio_acoustic_hidden_dim,
kernel_size=patch_size,
stride=patch_size,
padding=0,
),
Lambda(lambda x: x.transpose(1, 2)),
)
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
timestep_r: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
context_latents: torch.Tensor,
use_cache: Optional[bool] = False,
past_key_values: Optional[EncoderDecoderCache] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
return_hidden_states: int = None,
custom_layers_config: Optional[dict] = None,
enable_early_exit: bool = False,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
):
use_cache = use_cache if use_cache is not None else self.use_cache
# Disable cache during training or when gradient checkpointing is enabled
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if self.training:
use_cache = False
# Initialize cache if needed (only during inference for auto-regressive generation)
if not self.training and use_cache and past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
# Compute timestep embeddings for diffusion conditioning
# Two embeddings: one for timestep t, one for timestep difference (t - r)
temb_t, timestep_proj_t = self.time_embed(timestep)
temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
# Combine embeddings
temb = temb_t + temb_r
timestep_proj = timestep_proj_t + timestep_proj_r
# Concatenate context latents (source latents + chunk masks) with hidden states
hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
# Record original sequence length for later restoration after padding
original_seq_len = hidden_states.shape[1]
# Apply padding if sequence length is not divisible by patch_size
# This ensures proper patch extraction
pad_length = 0
if hidden_states.shape[1] % self.patch_size != 0:
pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0)
# Project input to patches and project encoder states
hidden_states = self.proj_in(hidden_states)
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
# Cache positions
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
)
# Position IDs
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
seq_len = hidden_states.shape[1]
encoder_seq_len = encoder_hidden_states.shape[1]
dtype = hidden_states.dtype
device = hidden_states.device
# Initialize Mask variables
full_attn_mask = None
sliding_attn_mask = None
encoder_attn_mask = None
decoder_attn_mask = None
# Target library discards the passed-in attention_mask for 4D mask
# construction (line 1384: attention_mask = None)
attention_mask = None
# 1. Full Attention (Bidirectional, Global)
full_attn_mask = create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=attention_mask,
sliding_window=None,
is_sliding_window=False,
is_causal=False
)
max_len = max(seq_len, encoder_seq_len)
encoder_attn_mask = create_4d_mask(
seq_len=max_len,
dtype=dtype,
device=device,
attention_mask=attention_mask,
sliding_window=None,
is_sliding_window=False,
is_causal=False
)
encoder_attn_mask = encoder_attn_mask[:, :, :seq_len, :encoder_seq_len]
# 2. Sliding Attention (Bidirectional, Local)
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len,
dtype=dtype,
device=device,
attention_mask=attention_mask,
sliding_window=self.sliding_window,
is_sliding_window=True,
is_causal=False
)
# Build mask mapping
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
"encoder_attention_mask": encoder_attn_mask,
}
# Create position embeddings to be shared across all decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
all_cross_attentions = () if output_attentions else None
# Handle early exit for custom layer configurations
max_needed_layer = float('inf')
if custom_layers_config is not None and enable_early_exit:
max_needed_layer = max(custom_layers_config.keys())
output_attentions = True
if all_cross_attentions is None:
all_cross_attentions = ()
# Process through transformer layers
for index_block, layer_module in enumerate(self.layers):
# Early exit optimization
if index_block > max_needed_layer:
break
# Prepare layer arguments
layer_args = (
hidden_states,
position_embeddings,
timestep_proj,
self_attn_mask_mapping[layer_module.attention_type],
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
encoder_hidden_states,
self_attn_mask_mapping["encoder_attention_mask"],
)
layer_kwargs = flash_attn_kwargs
# Use gradient checkpointing if enabled
layer_outputs = gradient_checkpoint_forward(
layer_module,
use_gradient_checkpointing,
use_gradient_checkpointing_offload,
*layer_args,
**layer_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions and self.layers[index_block].use_cross_attention:
# layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights)
if len(layer_outputs) >= 3:
all_cross_attentions += (layer_outputs[2],)
if return_hidden_states:
return hidden_states
# Extract scale-shift parameters for adaptive output normalization
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)
# Apply adaptive layer norm: norm(x) * (1 + scale) + shift
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
# Project output: de-patchify back to original sequence format
hidden_states = self.proj_out(hidden_states)
# Crop back to original sequence length to ensure exact length match (remove padding)
hidden_states = hidden_states[:, :original_seq_len, :]
outputs = (hidden_states, past_key_values)
if output_attentions:
outputs += (all_cross_attentions,)
return outputs

View File

@@ -0,0 +1,53 @@
import torch
class AceStepTextEncoder(torch.nn.Module):
def __init__(
self,
):
super().__init__()
from transformers import Qwen3Config, Qwen3Model
config = Qwen3Config(
attention_bias=False,
attention_dropout=0.0,
bos_token_id=151643,
dtype="bfloat16",
eos_token_id=151643,
head_dim=128,
hidden_act="silu",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=3072,
layer_types=["full_attention"] * 28,
max_position_embeddings=32768,
max_window_layers=28,
model_type="qwen3",
num_attention_heads=16,
num_hidden_layers=28,
num_key_value_heads=8,
pad_token_id=151643,
rms_norm_eps=1e-06,
rope_scaling=None,
rope_theta=1000000,
sliding_window=None,
tie_word_embeddings=True,
use_cache=True,
use_sliding_window=False,
vocab_size=151669,
)
self.model = Qwen3Model(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.Tensor,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
return outputs.last_hidden_state

View File

@@ -0,0 +1,732 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio Tokenizer — VAE latent discretization pathway.
Contains:
- AceStepAudioTokenizer: continuous VAE latent → discrete FSQ tokens
- AudioTokenDetokenizer: discrete tokens → continuous VAE-latent-shaped features
Only used in cover song mode (is_covers=True). Bypassed in text-to-music.
"""
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from ..core.attention import attention_forward
from ..core.gradient import gradient_checkpoint_forward
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutput
from transformers.processing_utils import Unpack
from transformers.utils import can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
)
from vector_quantize_pytorch import ResidualFSQ
logger = logging.get_logger(__name__)
def create_4d_mask(
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: Optional[torch.Tensor] = None,
sliding_window: Optional[int] = None,
is_sliding_window: bool = False,
is_causal: bool = True,
) -> torch.Tensor:
indices = torch.arange(seq_len, device=device)
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
if is_causal:
valid_mask = valid_mask & (diff >= 0)
if is_sliding_window and sliding_window is not None:
if is_causal:
valid_mask = valid_mask & (diff <= sliding_window)
else:
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
valid_mask = valid_mask & padding_mask_4d
min_dtype = torch.finfo(dtype).min
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
mask_tensor.masked_fill_(valid_mask, 0.0)
return mask_tensor
class Lambda(nn.Module):
def __init__(self, func):
super().__init__()
self.func = func
def forward(self, x):
return self.func(x)
class AceStepAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
is_cross_attention: bool = False,
is_causal: bool = False,
):
super().__init__()
self.layer_idx = layer_idx
self.head_dim = head_dim or hidden_size // num_attention_heads
self.num_key_value_groups = num_attention_heads // num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = attention_dropout
if is_cross_attention:
is_causal = False
self.is_causal = is_causal
self.is_cross_attention = is_cross_attention
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
self.attention_type = layer_types[layer_idx]
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
output_attentions: Optional[bool] = False,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
if is_cross_attention:
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
curr_past_key_value = past_key_value.cross_attention_cache
if not is_updated:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
past_key_value.is_updated[self.layer_idx] = True
else:
key_states = curr_past_key_value.layers[self.layer_idx].keys
value_states = curr_past_key_value.layers[self.layer_idx].values
else:
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
else:
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if self.num_key_value_groups > 1:
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
attn_output = attention_forward(
query_states, key_states, value_states,
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
attn_mask=attention_mask,
)
attn_weights = None
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
class AceStepEncoderLayer(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
rms_norm_eps: float,
attention_bias: bool,
attention_dropout: float,
layer_types: list,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = None,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.layer_idx = layer_idx
self.self_attn = AceStepAttention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
is_cross_attention=False,
is_causal=False,
)
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
mlp_config = type('Config', (), {
'hidden_size': hidden_size,
'intermediate_size': intermediate_size,
'hidden_act': 'silu',
})()
self.mlp = Qwen3MLP(mlp_config)
self.attention_type = layer_types[layer_idx]
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
use_cache=False,
past_key_value=None,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class AttentionPooler(nn.Module):
"""Pools every pool_window_size frames into 1 representation via transformer + CLS token."""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
num_attention_pooler_hidden_layers: int = 2,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Default matches target library config (24 alternating entries).
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Slice layer_types to our own layer count
pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': pooler_layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=pooler_layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_attention_pooler_hidden_layers)
])
@can_return_tuple
def forward(
self,
x,
attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> torch.Tensor:
B, T, P, D = x.shape
x = self.embed_tokens(x)
special_tokens = self.special_token.expand(B, T, 1, -1).to(x.device)
x = torch.cat([special_tokens, x], dim=2)
x = rearrange(x, "b t p c -> (b t) p c")
cache_position = torch.arange(0, x.shape[1], device=x.device)
position_ids = cache_position.unsqueeze(0)
hidden_states = x
position_embeddings = self.rotary_emb(hidden_states, position_ids)
seq_len = x.shape[1]
dtype = x.dtype
device = x.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
for layer_module in self.layers:
layer_outputs = layer_module(
hidden_states, position_embeddings,
attention_mask=self_attn_mask_mapping[layer_module.attention_type],
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
cls_output = hidden_states[:, 0, :]
return rearrange(cls_output, "(b t) c -> b t c", b=B)
class AceStepAudioTokenizer(nn.Module):
"""Converts continuous acoustic features (VAE latents) into discrete quantized tokens.
Input: [B, T, 64] (VAE latent dim)
Output: quantized [B, T/5, 2048], indices [B, T/5, 1]
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
audio_acoustic_hidden_dim: int = 64,
pool_window_size: int = 5,
fsq_dim: int = 2048,
fsq_input_levels: list = None,
fsq_input_num_quantizers: int = 1,
num_attention_pooler_hidden_layers: int = 2,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Default matches target library config (24 alternating entries).
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
self.pool_window_size = pool_window_size
self.fsq_dim = fsq_dim
self.fsq_input_levels = fsq_input_levels or [8, 8, 8, 5, 5, 5]
self.fsq_input_num_quantizers = fsq_input_num_quantizers
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
# Slice layer_types for the attention pooler
pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
self.attention_pooler = AttentionPooler(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=pooler_layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
)
self.quantizer = ResidualFSQ(
dim=self.fsq_dim,
levels=self.fsq_input_levels,
num_quantizers=self.fsq_input_num_quantizers,
force_quantization_f32=False, # avoid autocast bug in vector_quantize_pytorch
)
@can_return_tuple
def forward(
self,
hidden_states: Optional[torch.FloatTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.audio_acoustic_proj(hidden_states)
hidden_states = self.attention_pooler(hidden_states)
quantized, indices = self.quantizer(hidden_states)
return quantized, indices
def tokenize(self, x):
"""Convenience: takes [B, T, 64], rearranges to patches, runs forward."""
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size)
return self.forward(x)
class AudioTokenDetokenizer(nn.Module):
"""Converts quantized audio tokens back to continuous acoustic representations.
Input: [B, T/5, hidden_size] (quantized vectors)
Output: [B, T, 64] (VAE-latent-shaped continuous features)
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
pool_window_size: int = 5,
audio_acoustic_hidden_dim: int = 64,
num_attention_pooler_hidden_layers: int = 2,
**kwargs,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Default matches target library config (24 alternating entries).
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
self.head_dim = head_dim or hidden_size // num_attention_heads
self.sliding_window = sliding_window
self.use_sliding_window = use_sliding_window
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.pool_window_size = pool_window_size
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Slice layer_types to our own layer count (use num_audio_decoder_hidden_layers)
detok_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
rope_config = type('RopeConfig', (), {
'hidden_size': hidden_size,
'num_attention_heads': num_attention_heads,
'num_key_value_heads': num_key_value_heads,
'head_dim': head_dim,
'max_position_embeddings': max_position_embeddings,
'rope_theta': rope_theta,
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
'rms_norm_eps': rms_norm_eps,
'attention_bias': attention_bias,
'attention_dropout': attention_dropout,
'hidden_act': 'silu',
'intermediate_size': intermediate_size,
'layer_types': detok_layer_types,
'sliding_window': sliding_window,
'_attn_implementation': self._attn_implementation,
})()
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
self.gradient_checkpointing = False
self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
self.layers = nn.ModuleList([
AceStepEncoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=detok_layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
layer_idx=layer_idx,
)
for layer_idx in range(num_attention_pooler_hidden_layers)
])
self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)
@can_return_tuple
def forward(
self,
x,
attention_mask: Optional[torch.Tensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> torch.Tensor:
B, T, D = x.shape
x = self.embed_tokens(x)
x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
special_tokens = self.special_tokens.expand(B, T, -1, -1)
x = x + special_tokens.to(x.device)
x = rearrange(x, "b t p c -> (b t) p c")
cache_position = torch.arange(0, x.shape[1], device=x.device)
position_ids = cache_position.unsqueeze(0)
hidden_states = x
position_embeddings = self.rotary_emb(hidden_states, position_ids)
seq_len = x.shape[1]
dtype = x.dtype
device = x.device
full_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=None,
is_sliding_window=False, is_causal=False
)
sliding_attn_mask = None
if self.use_sliding_window:
sliding_attn_mask = create_4d_mask(
seq_len=seq_len, dtype=dtype, device=device,
attention_mask=attention_mask, sliding_window=self.sliding_window,
is_sliding_window=True, is_causal=False
)
self_attn_mask_mapping = {
"full_attention": full_attn_mask,
"sliding_attention": sliding_attn_mask,
}
for layer_module in self.layers:
layer_outputs = layer_module(
hidden_states, position_embeddings,
attention_mask=self_attn_mask_mapping[layer_module.attention_type],
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_out(hidden_states)
return rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.pool_window_size)
class AceStepTokenizer(nn.Module):
"""Container for AceStepAudioTokenizer + AudioTokenDetokenizer.
Provides encode/decode convenience methods for VAE latent discretization.
Used in cover song mode to convert source audio latents to discrete tokens
and back to continuous conditioning hints.
"""
def __init__(
self,
hidden_size: int = 2048,
intermediate_size: int = 6144,
num_attention_heads: int = 16,
num_key_value_heads: int = 8,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
attention_dropout: float = 0.0,
layer_types: Optional[list] = None,
head_dim: Optional[int] = None,
sliding_window: Optional[int] = 128,
use_sliding_window: bool = True,
rope_theta: float = 1000000,
max_position_embeddings: int = 32768,
initializer_range: float = 0.02,
audio_acoustic_hidden_dim: int = 64,
pool_window_size: int = 5,
fsq_dim: int = 2048,
fsq_input_levels: list = None,
fsq_input_num_quantizers: int = 1,
num_attention_pooler_hidden_layers: int = 2,
num_audio_decoder_hidden_layers: int = 24,
**kwargs,
):
super().__init__()
# Default layer_types matches target library config (24 alternating entries).
# Sub-modules (pooler/detokenizer) slice first N entries for their own layer count.
if layer_types is None:
layer_types = ["sliding_attention", "full_attention"] * 12
self.tokenizer = AceStepAudioTokenizer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
pool_window_size=pool_window_size,
fsq_dim=fsq_dim,
fsq_input_levels=fsq_input_levels,
fsq_input_num_quantizers=fsq_input_num_quantizers,
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
**kwargs,
)
self.detokenizer = AudioTokenDetokenizer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
layer_types=layer_types,
head_dim=head_dim,
sliding_window=sliding_window,
use_sliding_window=use_sliding_window,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
pool_window_size=pool_window_size,
audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
**kwargs,
)
def encode(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""VAE latent [B, T, 64] → discrete tokens."""
return self.tokenizer(hidden_states)
def decode(self, quantized: torch.Tensor) -> torch.Tensor:
"""Discrete tokens [B, T/5, hidden_size] → continuous [B, T, 64]."""
return self.detokenizer(quantized)
def tokenize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convenience: [B, T, 64] → quantized + indices via patch rearrangement."""
return self.tokenizer.tokenize(x)

View File

@@ -0,0 +1,287 @@
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture).
This is a CNN-based VAE for audio waveform encoding/decoding.
It uses weight-normalized convolutions and Snake1d activations.
Does NOT depend on diffusers — pure nn.Module implementation.
"""
import math
from typing import Optional
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm
class Snake1d(nn.Module):
"""Snake activation: x + 1/(beta+eps) * sin(alpha*x)^2."""
def __init__(self, hidden_dim: int, logscale: bool = True):
super().__init__()
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.logscale = logscale
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
shape = hidden_states.shape
alpha = torch.exp(self.alpha) if self.logscale else self.alpha
beta = torch.exp(self.beta) if self.logscale else self.beta
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
return hidden_states.reshape(shape)
class OobleckResidualUnit(nn.Module):
"""Residual unit: Snake1d → Conv1d(dilated) → Snake1d → Conv1d(1×1) + skip."""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
self.snake2 = Snake1d(dimension)
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
output = self.conv1(self.snake1(hidden_state))
output = self.conv2(self.snake2(output))
padding = (hidden_state.shape[-1] - output.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
return hidden_state + output
class OobleckEncoderBlock(nn.Module):
"""Encoder block: 3 residual units + downsampling conv."""
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
super().__init__()
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
self.snake1 = Snake1d(input_dim)
self.conv1 = weight_norm(
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
return self.conv1(hidden_state)
class OobleckDecoderBlock(nn.Module):
"""Decoder block: upsampling conv + 3 residual units."""
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
super().__init__()
self.snake1 = Snake1d(input_dim)
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2),
)
)
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
return self.res_unit3(hidden_state)
class OobleckEncoder(nn.Module):
"""Full encoder: audio → latent representation [B, encoder_hidden_size, T'].
conv1 → [blocks] → snake1 → conv2
"""
def __init__(
self,
encoder_hidden_size: int = 128,
audio_channels: int = 2,
downsampling_ratios: list = None,
channel_multiples: list = None,
):
super().__init__()
downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
channel_multiples = [1] + channel_multiples
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
self.block = nn.ModuleList()
for stride_index, stride in enumerate(downsampling_ratios):
self.block.append(
OobleckEncoderBlock(
input_dim=encoder_hidden_size * channel_multiples[stride_index],
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
stride=stride,
)
)
d_model = encoder_hidden_size * channel_multiples[-1]
self.snake1 = Snake1d(d_model)
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv1(hidden_state)
for block in self.block:
hidden_state = block(hidden_state)
hidden_state = self.snake1(hidden_state)
return self.conv2(hidden_state)
class OobleckDecoder(nn.Module):
"""Full decoder: latent → audio waveform [B, audio_channels, T].
conv1 → [blocks] → snake1 → conv2(no bias)
"""
def __init__(
self,
channels: int = 128,
input_channels: int = 64,
audio_channels: int = 2,
upsampling_ratios: list = None,
channel_multiples: list = None,
):
super().__init__()
upsampling_ratios = upsampling_ratios or [10, 6, 4, 4, 2]
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
channel_multiples = [1] + channel_multiples
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
self.block = nn.ModuleList()
for stride_index, stride in enumerate(upsampling_ratios):
self.block.append(
OobleckDecoderBlock(
input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index],
output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1],
stride=stride,
)
)
self.snake1 = Snake1d(channels)
# conv2 has no bias (matches checkpoint: only weight_g/weight_v, no bias key)
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.conv1(hidden_state)
for block in self.block:
hidden_state = block(hidden_state)
hidden_state = self.snake1(hidden_state)
return self.conv2(hidden_state)
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = torch.randn(
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
class AceStepVAE(nn.Module):
"""Audio VAE for ACE-Step (AutoencoderOobleck architecture).
Encodes audio waveform → latent, decodes latent → audio waveform.
Uses Snake1d activations and weight-normalized convolutions.
"""
def __init__(
self,
encoder_hidden_size: int = 128,
downsampling_ratios: list = None,
channel_multiples: list = None,
decoder_channels: int = 128,
decoder_input_channels: int = 64,
audio_channels: int = 2,
sampling_rate: int = 48000,
):
super().__init__()
downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
upsampling_ratios = downsampling_ratios[::-1]
self.encoder = OobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
self.decoder = OobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=upsampling_ratios,
channel_multiples=channel_multiples,
)
self.sampling_rate = sampling_rate
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
h = self.encoder(x)
output = OobleckDiagonalGaussianDistribution(h).sample()
return output
def decode(self, z: torch.Tensor) -> torch.Tensor:
"""Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
return self.decoder(z)
def forward(self, sample: torch.Tensor) -> torch.Tensor:
"""Full round-trip: encode → decode."""
z = self.encode(sample)
return self.decode(z)
def remove_weight_norm(self):
"""Remove weight normalization from all conv layers (for export/inference)."""
for module in self.modules():
if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
remove_weight_norm(module)

View File

@@ -1,78 +0,0 @@
import torch
class SDTextEncoder(torch.nn.Module):
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
vocab_size=49408,
layer_norm_eps=1e-05,
hidden_act="quick_gelu",
initializer_factor=1.0,
initializer_range=0.02,
bos_token_id=0,
eos_token_id=2,
pad_token_id=1,
projection_dim=768,
):
super().__init__()
from transformers import CLIPConfig, CLIPTextModel
config = CLIPConfig(
text_config={
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"num_hidden_layers": num_hidden_layers,
"num_attention_heads": num_attention_heads,
"max_position_embeddings": max_position_embeddings,
"vocab_size": vocab_size,
"layer_norm_eps": layer_norm_eps,
"hidden_act": hidden_act,
"initializer_factor": initializer_factor,
"initializer_range": initializer_range,
"bos_token_id": bos_token_id,
"eos_token_id": eos_token_id,
"pad_token_id": pad_token_id,
"projection_dim": projection_dim,
"dropout": 0.0,
},
vision_config={
"hidden_size": hidden_size,
"intermediate_size": intermediate_size,
"num_hidden_layers": num_hidden_layers,
"num_attention_heads": num_attention_heads,
"max_position_embeddings": max_position_embeddings,
"layer_norm_eps": layer_norm_eps,
"hidden_act": hidden_act,
"initializer_factor": initializer_factor,
"initializer_range": initializer_range,
"projection_dim": projection_dim,
},
projection_dim=projection_dim,
)
self.model = CLIPTextModel(config.text_config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_hidden_states=True,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)
if output_hidden_states:
return outputs.last_hidden_state, outputs.hidden_states
return outputs.last_hidden_state

View File

@@ -1,912 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim + self.freq_shift
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
"""Attention block matching diffusers checkpoint key format.
Keys: to_q.weight, to_k.weight, to_v.weight, to_out.0.weight, to_out.0.bias
"""
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Query
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
# Key/Value
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
# Reshape for multi-head attention
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
# Self-attention
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
# Cross-attention
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# Feed-forward
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
"""2D Transformer block wrapper matching diffusers checkpoint structure.
Keys: norm.weight/bias, proj_in.weight/bias, transformer_blocks.X.*, proj_out.weight/bias
"""
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
):
super().__init__()
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Conv2d(in_channels, num_attention_heads * attention_head_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=num_attention_heads * attention_head_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
self.proj_out = nn.Conv2d(num_attention_heads * attention_head_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
# Normalize and project to sequence
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
# Transformer blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
# Project back to 2D
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
# Pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# There is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== UNet2DConditionModel =====
class UNet2DConditionModel(nn.Module):
"""Stable Diffusion UNet with cross-attention conditioning.
state_dict keys match the diffusers UNet2DConditionModel checkpoint format.
"""
def __init__(
self,
sample_size=64,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
cross_attention_dim=768,
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
# Time embedding
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
# Input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
# Down blocks
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
downsample=not is_final_block,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
# Mid block
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
)
# Up blocks
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
# in_channels for up blocks: diffusers uses reversed_block_out_channels[min(i+1, len-1)]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
upsample=not is_final_block,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
# Output
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
# 1. Time embedding
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
# 2. Pre-process
sample = self.conv_in(sample)
# 3. Down
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
# 4. Mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. Up
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
# 6. Post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample

View File

@@ -1,642 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Optional
class DiagonalGaussianDistribution:
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
# randn_like doesn't accept generator on all torch versions
sample = torch.randn(self.mean.shape, generator=generator,
device=self.parameters.device, dtype=self.parameters.dtype)
return self.mean + self.std * sample
def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
if self.deterministic:
return torch.tensor([0.0])
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3],
)
def mode(self) -> torch.Tensor:
return self.mean
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
else:
raise ValueError(f"Unsupported non_linearity: {non_linearity}")
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class DownEncoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=downsample_padding)
])
else:
self.downsamplers = None
def forward(self, hidden_states, *args, **kwargs):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class UpDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
output_scale_factor=1.0,
add_upsample=True,
temb_channels=None,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, temb=None):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class UNetMidBlock2D(nn.Module):
def __init__(
self,
in_channels,
temb_channels=None,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
add_attention=True,
attention_head_dim=1,
output_scale_factor=1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
if attention_head_dim is None:
attention_head_dim = in_channels
for _ in range(num_layers):
if self.add_attention:
attentions.append(
AttentionBlock(
in_channels,
num_groups=resnet_groups,
eps=resnet_eps,
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class AttentionBlock(nn.Module):
"""Simple attention block for VAE mid block.
Mirrors diffusers Attention class with AttnProcessor2_0 for VAE use case.
Uses modern key names (to_q, to_k, to_v, to_out) matching in-memory diffusers structure.
Checkpoint uses deprecated keys (query, key, value, proj_attn) — mapped via converter.
"""
def __init__(self, channels, num_groups=32, eps=1e-6):
super().__init__()
self.channels = channels
self.eps = eps
self.heads = 1
self.rescale_output_factor = 1.0
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=eps, affine=True)
self.to_q = nn.Linear(channels, channels, bias=True)
self.to_k = nn.Linear(channels, channels, bias=True)
self.to_v = nn.Linear(channels, channels, bias=True)
self.to_out = nn.ModuleList([
nn.Linear(channels, channels, bias=True),
nn.Dropout(0.0),
])
def forward(self, hidden_states):
residual = hidden_states
# Group norm
hidden_states = self.group_norm(hidden_states)
# Flatten spatial dims: (B, C, H, W) -> (B, H*W, C)
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# QKV projection
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
# Reshape for attention: (B, seq, dim) -> (B, heads, seq, head_dim)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
# Scaled dot-product attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Reshape back: (B, heads, seq, head_dim) -> (B, seq, heads*head_dim)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Output projection + dropout
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
# Reshape back to 4D and add residual
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
hidden_states = hidden_states + residual
# Rescale output factor
hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
class Downsample2D(nn.Module):
"""Downsampling layer matching diffusers Downsample2D with use_conv=True.
Key names: conv.weight/bias.
When padding=0, applies explicit F.pad before conv to match dimension.
"""
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
import torch.nn.functional as F
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
"""Upsampling layer with key names matching diffusers checkpoint: conv.weight/bias."""
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states):
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
class Encoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.down_blocks = nn.ModuleList([])
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = DownEncoderBlock2D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_downsample=not is_final_block,
downsample_padding=0,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
add_attention=mid_block_add_attention,
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
def forward(self, sample):
sample = self.conv_in(sample)
for down_block in self.down_blocks:
sample = down_block(sample)
sample = self.mid_block(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class Decoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
norm_type="group",
mid_block_add_attention=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.up_blocks = nn.ModuleList([])
temb_channels = in_channels if norm_type == "spatial" else None
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=temb_channels,
add_attention=mid_block_add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_upsample=not is_final_block,
temb_channels=temb_channels,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(self, sample, latent_embeds=None):
sample = self.conv_in(sample)
sample = self.mid_block(sample, latent_embeds)
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class StableDiffusionVAE(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
block_out_channels=(128, 256, 512, 512),
layers_per_block=2,
act_fn="silu",
latent_channels=4,
norm_num_groups=32,
sample_size=512,
scaling_factor=0.18215,
shift_factor=None,
latents_mean=None,
latents_std=None,
force_upcast=True,
use_quant_conv=True,
use_post_quant_conv=True,
mid_block_add_attention=True,
):
super().__init__()
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
double_z=True,
mid_block_add_attention=mid_block_add_attention,
)
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
mid_block_add_attention=mid_block_add_attention,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
self.latents_mean = latents_mean
self.latents_std = latents_std
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
self.sample_size = sample_size
self.force_upcast = force_upcast
def _encode(self, x):
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
return h
def encode(self, x):
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
return posterior
def _decode(self, z):
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
return self.decoder(z)
def decode(self, z):
return self._decode(z)
def forward(self, sample, sample_posterior=True, return_dict=True, generator=None):
posterior = self.encode(sample)
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
# Scale latent
z = z * self.scaling_factor
decode = self.decode(z)
if return_dict:
return {"sample": decode, "posterior": posterior, "latent_sample": z}
return decode, posterior

View File

@@ -1,62 +0,0 @@
import torch
class SDXLTextEncoder2(torch.nn.Module):
def __init__(
self,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
vocab_size=49408,
layer_norm_eps=1e-05,
hidden_act="gelu",
initializer_factor=1.0,
initializer_range=0.02,
bos_token_id=0,
eos_token_id=2,
pad_token_id=1,
projection_dim=1280,
):
super().__init__()
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
config = CLIPTextConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
max_position_embeddings=max_position_embeddings,
vocab_size=vocab_size,
layer_norm_eps=layer_norm_eps,
hidden_act=hidden_act,
initializer_factor=initializer_factor,
initializer_range=initializer_range,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
projection_dim=projection_dim,
)
self.model = CLIPTextModelWithProjection(config)
self.config = config
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
output_hidden_states=True,
**kwargs,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs,
)
if output_hidden_states:
return outputs.text_embeds, outputs.hidden_states
return outputs.text_embeds

View File

@@ -1,922 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
# ===== Time Embedding =====
class Timesteps(nn.Module):
def __init__(self, num_channels, flip_sin_to_cos=True, freq_shift=0):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.freq_shift = freq_shift
def forward(self, timesteps):
half_dim = self.num_channels // 2
exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / half_dim + self.freq_shift
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
sin_emb = torch.sin(emb)
cos_emb = torch.cos(emb)
if self.flip_sin_to_cos:
emb = torch.cat([cos_emb, sin_emb], dim=-1)
else:
emb = torch.cat([sin_emb, cos_emb], dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels, time_embed_dim, act_fn="silu", out_dim=None):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU() if act_fn == "silu" else nn.GELU()
out_dim = out_dim if out_dim is not None else time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, out_dim)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
# ===== ResNet Blocks =====
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
):
super().__init__()
self.pre_norm = pre_norm
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps)
self.conv1 = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = nn.Linear(temb_channels, out_channels or in_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = nn.Linear(temb_channels, 2 * (out_channels or in_channels))
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels or in_channels, eps=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels or in_channels, out_channels or in_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = nn.SiLU()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
elif non_linearity == "gelu":
self.nonlinearity = nn.GELU()
elif non_linearity == "relu":
self.nonlinearity = nn.ReLU()
self.use_conv_shortcut = conv_shortcut
self.conv_shortcut = None
if conv_shortcut:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels or in_channels, kernel_size=1, stride=1, padding=0) if in_channels != (out_channels or in_channels) else None
def forward(self, input_tensor, temb=None):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb).unsqueeze(-1).unsqueeze(-1)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# ===== Transformer Blocks =====
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, dropout=0.0):
super().__init__()
self.net = nn.ModuleList([
GEGLU(dim, dim * 4),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim if dim_out is None else dim_out),
])
def forward(self, hidden_states):
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class Attention(nn.Module):
def __init__(
self,
query_dim,
heads=8,
dim_head=64,
dropout=0.0,
bias=False,
upcast_attention=False,
cross_attention_dim=None,
):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.inner_dim = inner_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([
nn.Linear(inner_dim, query_dim, bias=True),
nn.Dropout(dropout),
])
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
query = self.to_q(hidden_states)
batch_size, seq_len, _ = query.shape
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
head_dim = self.inner_dim // self.heads
query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.inner_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[1](hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
cross_attention_dim=None,
upcast_attention=False,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
)
self.norm2 = nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
bias=False,
upcast_attention=upcast_attention,
cross_attention_dim=cross_attention_dim,
)
self.norm3 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, dropout=dropout)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
attn_output = self.attn1(self.norm1(hidden_states))
hidden_states = attn_output + hidden_states
attn_output = self.attn2(self.norm2(hidden_states), encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
ff_output = self.ff(self.norm3(hidden_states))
hidden_states = ff_output + hidden_states
return hidden_states
class Transformer2DModel(nn.Module):
def __init__(
self,
num_attention_heads=16,
attention_head_dim=64,
in_channels=320,
num_layers=1,
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
upcast_attention=False,
use_linear_projection=False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.use_linear_projection = use_linear_projection
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim, bias=True)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=True)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
dim=inner_dim,
n_heads=num_attention_heads,
d_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
upcast_attention=upcast_attention,
)
for _ in range(num_layers)
])
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels, bias=True)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, bias=True)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.use_linear_projection:
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, -1, channel)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
else:
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
# ===== Down/Up Blocks =====
class CrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
downsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class DownBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
downsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
in_channels_i = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels_i,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if downsample:
self.downsamplers = nn.ModuleList([
Downsample2D(out_channels, out_channels, padding=1)
])
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = []
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
output_states.append(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states.append(hidden_states)
return hidden_states, tuple(output_states)
class CrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
upsample=True,
use_linear_projection=False,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=out_channels // attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet, attn in zip(self.resnets, self.attentions):
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
class UpBlock2D(nn.Module):
def __init__(
self,
in_channels,
out_channels,
prev_output_channel,
temb_channels=1280,
dropout=0.0,
num_layers=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
upsample=True,
):
super().__init__()
self.has_cross_attention = False
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if upsample:
self.upsamplers = nn.ModuleList([
Upsample2D(out_channels, out_channels)
])
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None):
for resnet in self.resnets:
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size=upsample_size)
return hidden_states
# ===== UNet Mid Block =====
class UNetMidBlock2DCrossAttn(nn.Module):
def __init__(
self,
in_channels,
temb_channels=1280,
dropout=0.0,
num_layers=1,
transformer_layers_per_block=1,
resnet_eps=1e-6,
resnet_time_scale_shift="default",
resnet_act_fn="swish",
resnet_groups=32,
resnet_pre_norm=True,
cross_attention_dim=768,
attention_head_dim=1,
use_linear_projection=False,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
Transformer2DModel(
num_attention_heads=attention_head_dim,
attention_head_dim=in_channels // attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
dropout=dropout,
norm_num_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
use_linear_projection=use_linear_projection,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=1.0,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
# ===== Downsample / Upsample =====
class Downsample2D(nn.Module):
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
self.padding = padding
def forward(self, hidden_states):
if self.padding == 0:
hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
return self.conv(hidden_states)
class Upsample2D(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, hidden_states, upsample_size=None):
if upsample_size is not None:
hidden_states = F.interpolate(hidden_states, size=upsample_size, mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
return self.conv(hidden_states)
# ===== SDXL UNet2DConditionModel =====
class SDXLUNet2DConditionModel(nn.Module):
def __init__(
self,
sample_size=128,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
block_out_channels=(320, 640, 1280),
layers_per_block=2,
cross_attention_dim=2048,
attention_head_dim=5,
transformer_layers_per_block=1,
norm_num_groups=32,
norm_eps=1e-5,
dropout=0.0,
act_fn="silu",
time_embedding_type="positional",
flip_sin_to_cos=True,
freq_shift=0,
time_embedding_dim=None,
resnet_time_scale_shift="default",
upcast_attention=False,
use_linear_projection=False,
addition_embed_type=None,
addition_time_embed_dim=None,
projection_class_embeddings_input_dim=None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.sample_size = sample_size
self.addition_embed_type = addition_embed_type
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
timestep_embedding_dim = time_embedding_dim or block_out_channels[0]
self.time_proj = Timesteps(timestep_embedding_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
time_embed_dim = block_out_channels[0] * 4
self.time_embedding = TimestepEmbedding(timestep_embedding_dim, time_embed_dim)
if addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=flip_sin_to_cos, freq_shift=freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
if "CrossAttn" in down_block_type:
down_block = CrossAttnDownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[i],
downsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
down_block = DownBlock2D(
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
downsample=not is_final_block,
)
self.down_blocks.append(down_block)
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=1,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim[-1],
use_linear_projection=use_linear_projection,
)
self.up_blocks = nn.ModuleList()
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
if "CrossAttn" in up_block_type:
up_block = CrossAttnUpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_head_dim=reversed_attention_head_dim[i],
upsample=not is_final_block,
use_linear_projection=use_linear_projection,
)
else:
up_block = UpBlock2D(
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
dropout=dropout,
num_layers=layers_per_block + 1,
resnet_eps=norm_eps,
resnet_time_scale_shift=resnet_time_scale_shift,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
upsample=not is_final_block,
)
self.up_blocks.append(up_block)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(self, sample, timestep, encoder_hidden_states, cross_attention_kwargs=None, timestep_cond=None, added_cond_kwargs=None, return_dict=True):
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
if self.addition_embed_type == "text_time":
text_embeds = added_cond_kwargs.get("text_embeds")
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb
sample = self.conv_in(sample)
down_block_res_samples = (sample,)
for down_block in self.down_blocks:
sample, res_samples = down_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
down_block_res_samples += res_samples
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
for up_block in self.up_blocks:
res_samples = down_block_res_samples[-len(up_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(up_block.resnets)]
upsample_size = down_block_res_samples[-1].shape[2:] if down_block_res_samples else None
sample = up_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return sample

View File

@@ -0,0 +1,584 @@
"""
ACE-Step Pipeline for DiffSynth-Studio.
Text-to-Music generation pipeline using ACE-Step 1.5 model.
"""
import re, torch
from typing import Optional, Dict, Any, List, Tuple
from tqdm import tqdm
import random, math
import torch.nn.functional as F
from einops import rearrange
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ace_step_dit import AceStepDiTModel
from ..models.ace_step_conditioner import AceStepConditionEncoder
from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE
from ..models.ace_step_tokenizer import AceStepTokenizer
class AceStepPipeline(BasePipeline):
"""Pipeline for ACE-Step text-to-music generation."""
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device,
torch_dtype=torch_dtype,
height_division_factor=1,
width_division_factor=1,
)
self.scheduler = FlowMatchScheduler("ACE-Step")
self.text_encoder: AceStepTextEncoder = None
self.conditioner: AceStepConditionEncoder = None
self.dit: AceStepDiTModel = None
self.vae: AceStepVAE = None
self.tokenizer_model: AceStepTokenizer = None
self.in_iteration_models = ("dit",)
self.units = [
AceStepUnit_TaskTypeChecker(),
AceStepUnit_PromptEmbedder(),
AceStepUnit_ReferenceAudioEmbedder(),
AceStepUnit_ContextLatentBuilder(),
AceStepUnit_ConditionEmbedder(),
AceStepUnit_NoiseInitializer(),
AceStepUnit_InputAudioEmbedder(),
]
self.model_fn = model_fn_ace_step
self.compilable_models = ["dit"]
self.sample_rate = 48000
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: str = get_device_type(),
model_configs: list[ModelConfig] = [],
text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
vram_limit: float = None,
):
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
pipe.vae.remove_weight_norm()
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
if text_tokenizer_config is not None:
text_tokenizer_config.download_if_necessary()
from transformers import AutoTokenizer
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
if silence_latent_config is not None:
silence_latent_config.download_if_necessary()
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
# Prompt
prompt: str,
cfg_scale: float = 1.0,
# Lyrics
lyrics: str = "",
# Task type
task_type: Optional[str] = "text2music",
# Reference audio
reference_audios: List[torch.Tensor] = None,
# Source audio
src_audio: torch.Tensor = None,
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
audio_cover_strength: float = 1.0,
# Audio codes
audio_code_string: Optional[str] = None,
# Inpainting
repainting_ranges: Optional[List[Tuple[float, float]]] = None,
repainting_strength: float = 1.0,
# Shape
duration: int = 60,
# Audio Meta
bpm: Optional[int] = 100,
keyscale: Optional[str] = "B minor",
timesignature: Optional[str] = "4",
vocal_language: Optional[str] = "unknown",
# Randomness
seed: int = None,
rand_device: str = "cpu",
# Steps
num_inference_steps: int = 8,
# Scheduler-specific parameters
shift: float = 1.0,
# Progress
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
# Parameters
inputs_posi = {"prompt": prompt, "positive": True}
inputs_nega = {"positive": False}
inputs_shared = {
"cfg_scale": cfg_scale,
"lyrics": lyrics,
"task_type": task_type,
"reference_audios": reference_audios,
"src_audio": src_audio, "audio_cover_strength": audio_cover_strength, "audio_code_string": audio_code_string,
"repainting_ranges": repainting_ranges, "repainting_strength": repainting_strength,
"duration": duration,
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
"seed": seed,
"rand_device": rand_device,
"num_inference_steps": num_inference_steps,
"shift": shift,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id,
)
inputs_shared["latents"] = self.step(
self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
)
# Decode
self.load_models_to_device(['vae'])
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
latents = inputs_shared["latents"].transpose(1, 2)
vae_output = self.vae.decode(latents)
audio_output = self.normalize_audio(vae_output, target_db=-1.0)
audio = self.output_audio_format_check(audio_output)
self.load_models_to_device([])
return audio
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
peak = torch.max(torch.abs(audio))
if peak < 1e-6:
return audio
target_amp = 10 ** (target_db / 20.0)
gain = target_amp / peak
return audio * gain
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
return
if inputs_shared.get("shared_noncover", None) is None:
return
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
if progress_id >= cover_steps:
inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
class AceStepUnit_TaskTypeChecker(PipelineUnit):
"""Check and compute sequence length from duration."""
def __init__(self):
super().__init__(
input_params=("task_type", "src_audio", "repainting_ranges", "audio_code_string"),
output_params=("task_type",),
)
def process(self, pipe, task_type, src_audio, repainting_ranges, audio_code_string):
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
if task_type == "cover":
assert (src_audio is not None) or (audio_code_string is not None), "For cover task, either src_audio or audio_code_string must be provided."
elif task_type == "repaint":
assert src_audio is not None, "For repaint task, src_audio must be provided."
assert repainting_ranges is not None and len(repainting_ranges) > 0, "For repaint task, inpainting_ranges must be provided and non-empty."
return {}
class AceStepUnit_PromptEmbedder(PipelineUnit):
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
INSTRUCTION_MAP = {
"text2music": "Fill the audio semantic mask based on the given conditions:",
"cover": "Generate audio semantic tokens based on the given conditions:",
"repaint": "Repaint the mask area based on the given conditions:",
"extract": "Extract the {TRACK_NAME} track from the audio:",
"extract_default": "Extract the track from the audio:",
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
"lego_default": "Generate the track based on the audio context:",
"complete": "Complete the input track with {TRACK_CLASSES}:",
"complete_default": "Complete the input track:",
}
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "prompt", "positive": "positive"},
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
onload_model_names=("text_encoder",)
)
def _encode_text(self, pipe, text, max_length=256):
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
text_inputs = pipe.tokenizer(
text,
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder(input_ids, attention_mask)
return hidden_states, attention_mask
def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
text_inputs = pipe.tokenizer(
lyric_text,
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(pipe.device)
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
return hidden_states, attention_mask
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
bpm = meta_dict.get("bpm", "N/A")
timesignature = meta_dict.get("timesignature", "N/A")
keyscale = meta_dict.get("keyscale", "N/A")
duration = meta_dict.get("duration", 30)
duration = f"{int(duration)} seconds"
return (
f"- bpm: {bpm}\n"
f"- timesignature: {timesignature}\n"
f"- keyscale: {keyscale}\n"
f"- duration: {duration}\n"
)
def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
if not positive:
return {}
pipe.load_models_to_device(['text_encoder'])
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
# TODO: remove this
newtext = prompt + "\n\n" + lyric_text
return {
"text_hidden_states": text_hidden_states,
"text_attention_mask": text_attention_mask,
"lyric_hidden_states": lyric_hidden_states,
"lyric_attention_mask": lyric_attention_mask,
}
class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("reference_audios",),
output_params=("reference_latents", "refer_audio_order_mask"),
onload_model_names=("vae",)
)
def process(self, pipe, reference_audios):
if reference_audios is not None:
pipe.load_models_to_device(['vae'])
reference_audios = [
self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
for reference_audio in reference_audios
]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
else:
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
if audio.ndim == 3 and audio.shape[0] == 1:
audio = audio.squeeze(0)
target_frames = 30 * 48000
segment_frames = 10 * 48000
if audio.shape[-1] < target_frames:
repeat_times = math.ceil(target_frames / audio.shape[-1])
audio = audio.repeat(1, repeat_times)
total_frames = audio.shape[-1]
segment_size = total_frames // 3
front_start = random.randint(0, max(0, segment_size - segment_frames))
front_audio = audio[:, front_start:front_start + segment_frames]
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
middle_audio = audio[:, middle_start:middle_start + segment_frames]
back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
back_audio = audio[:, back_start:back_start + segment_frames]
return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Infer packed reference-audio latents and order mask."""
refer_audio_order_mask = []
refer_audio_latents = []
for batch_idx, refer_audios in enumerate(refer_audioss):
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
refer_audio_latent = pipe.silence_latent[:, :750, :]
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
else:
for refer_audio in refer_audios:
refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
refer_audio_latents.append(refer_audio_latent)
refer_audio_order_mask.append(batch_idx)
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
return refer_audio_latents, refer_audio_order_mask
class AceStepUnit_ConditionEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
output_params=("encoder_hidden_states", "encoder_attention_mask"),
onload_model_names=("conditioner",),
)
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
pipe.load_models_to_device(['conditioner'])
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
text_hidden_states=inputs_posi.get("text_hidden_states", None),
text_attention_mask=inputs_posi.get("text_attention_mask", None),
lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
reference_latents=inputs_shared.get("reference_latents", None),
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
inputs_shared["vocal_language"], "text2music")
encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
**hidden_states_noncover,
reference_latents=inputs_shared.get("reference_latents", None),
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
)
duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
if inputs_shared["cfg_scale"] != 1.0:
inputs_shared["nega_noncover"] = {
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
),
"encoder_attention_mask": encoder_attention_mask_noncover,
}
return inputs_shared, inputs_posi, inputs_nega
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("duration", "src_audio", "audio_code_string", "task_type", "repainting_ranges", "repainting_strength"),
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
onload_model_names=("vae", "tokenizer_model",),
)
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
available = pipe.silence_latent.shape[1]
if length <= available:
return pipe.silence_latent[0, :length, :]
repeats = (length + available - 1) // available
tiled = pipe.silence_latent[0].repeat(repeats, 1)
return tiled[:length, :]
def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
if x.shape[1] % pool_window_size != 0:
pad_len = pool_window_size - (x.shape[1] % pool_window_size)
x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
quantized, indices = tokenizer(x)
return quantized
@staticmethod
def _parse_audio_code_string(code_str: str) -> list:
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
if not code_str:
return []
try:
codes = []
max_audio_code = 63999
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
code_value = int(x)
codes.append(max(0, min(code_value, max_audio_code)))
except Exception as e:
raise ValueError(f"Invalid audio_code_string format: {e}")
return codes
def pad_src_audio(self, pipe, src_audio, task_type, repainting_ranges):
if task_type != "repaint" or repainting_ranges is None:
return src_audio, repainting_ranges, None, None
min_left = min([start for start, end in repainting_ranges])
max_right = max([end for start, end in repainting_ranges])
total_length = src_audio.shape[-1] // pipe.vae.sampling_rate
pad_left = max(0, -min_left)
pad_right = max(0, max_right - total_length)
if pad_left > 0 or pad_right > 0:
padding_frames_left, padding_frames_right = pad_left * pipe.vae.sampling_rate, pad_right * pipe.vae.sampling_rate
src_audio = F.pad(src_audio, (padding_frames_left, padding_frames_right), value=0.0)
repainting_ranges = [(start + pad_left, end + pad_left) for start, end in repainting_ranges]
return src_audio, repainting_ranges, pad_left, pad_right
def parse_repaint_masks(self, pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right):
if task_type != "repaint" or repainting_ranges is None:
return None, src_latents
# let repainting area be repainting_strength, non-repainting area be 0.0, and blend at the boundary with cf_frames.
max_latent_length = src_latents.shape[1]
denoise_mask = torch.zeros((1, max_latent_length, 1), dtype=pipe.torch_dtype, device=pipe.device)
for start, end in repainting_ranges:
start_frame = start * pipe.vae.sampling_rate // 1920
end_frame = end * pipe.vae.sampling_rate // 1920
denoise_mask[:, start_frame:end_frame, :] = repainting_strength
# set padding areas to 1.0 (full repaint) to avoid artifacts at the boundaries caused by padding
pad_left_frames = pad_left * pipe.vae.sampling_rate // 1920
pad_right_frames = pad_right * pipe.vae.sampling_rate // 1920
denoise_mask[:, :pad_left_frames, :] = 1
denoise_mask[:, max_latent_length - pad_right_frames:, :] = 1
silent_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
src_latents = src_latents * (1 - denoise_mask) + silent_latents * denoise_mask
return denoise_mask, src_latents
def process(self, pipe, duration, src_audio, audio_code_string, task_type=None, repainting_ranges=None, repainting_strength=None):
# get src_latents from audio_code_string > src_audio > silence
source_latents = None
denoise_mask = None
if audio_code_string is not None:
# use audio_cede_string to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
code_ids = self._parse_audio_code_string(audio_code_string)
quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
max_latent_length = src_latents.shape[1]
elif src_audio is not None:
# use src_audio to get src_latents.
pipe.load_models_to_device(self.onload_model_names)
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
src_audio = torch.clamp(src_audio, -1.0, 1.0)
src_audio, repainting_ranges, pad_left, pad_right = self.pad_src_audio(pipe, src_audio, task_type, repainting_ranges)
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
source_latents = src_latents # cache for potential use in audio inpainting tasks
denoise_mask, src_latents = self.parse_repaint_masks(pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right)
if task_type == "cover":
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
max_latent_length = src_latents.shape[1]
else:
# use silence latents.
max_latent_length = int(duration * pipe.sample_rate // 1920)
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
return {"context_latents": context_latents, "attention_mask": attention_mask, "src_latents": source_latents, "denoise_mask": denoise_mask}
class AceStepUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_latents", "seed", "rand_device", "src_latents"),
output_params=("noise",),
)
def process(self, pipe, context_latents, seed, rand_device, src_latents):
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
if src_latents is not None:
noise = pipe.scheduler.add_noise(src_latents, noise, timestep=pipe.scheduler.timesteps[0])
return {"noise": noise}
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
"""Only for training."""
def __init__(self):
super().__init__(
input_params=("noise", "input_audio"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe, noise, input_audio):
if input_audio is None:
return {"latents": noise}
if pipe.scheduler.training:
pipe.load_models_to_device(self.onload_model_names)
input_audio, sample_rate = input_audio
input_audio = torch.clamp(input_audio, -1.0, 1.0)
if input_audio.dim() == 2:
input_audio = input_audio.unsqueeze(0)
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
input_latents = input_latents[:, :noise.shape[1]]
return {"input_latents": input_latents}
def model_fn_ace_step(
dit: AceStepDiTModel,
latents=None,
timestep=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
context_latents=None,
attention_mask=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
decoder_outputs = dit(
hidden_states=latents,
timestep=timestep,
timestep_r=timestep,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
context_latents=context_latents,
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
)[0]
return decoder_outputs

View File

@@ -1,230 +0,0 @@
import torch
from PIL import Image
from tqdm import tqdm
from typing import Union
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion.ddim_scheduler import DDIMScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from transformers import AutoTokenizer, CLIPTextModel
from ..models.stable_diffusion_text_encoder import SDTextEncoder
from ..models.stable_diffusion_unet import UNet2DConditionModel
from ..models.stable_diffusion_vae import StableDiffusionVAE
class StableDiffusionPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.float16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=8, width_division_factor=8,
)
self.scheduler = DDIMScheduler()
self.text_encoder: SDTextEncoder = None
self.unet: UNet2DConditionModel = None
self.vae: StableDiffusionVAE = None
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("unet",)
self.units = [
SDUnit_ShapeChecker(),
SDUnit_PromptEmbedder(),
SDUnit_NoiseInitializer(),
SDUnit_InputImageEmbedder(),
]
self.model_fn = model_fn_stable_diffusion
self.compilable_models = ["unet"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.float16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
vram_limit: float = None,
):
pipe = StableDiffusionPipeline(device=device, torch_dtype=torch_dtype)
# Override vram_config to use the specified torch_dtype for all models
for mc in model_configs:
mc._vram_config_override = {
'onload_dtype': torch_dtype,
'computation_dtype': torch_dtype,
}
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
pipe.unet = model_pool.fetch_model("stable_diffusion_unet")
pipe.vae = model_pool.fetch_model("stable_diffusion_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 7.5,
height: int = 512,
width: int = 512,
seed: int = None,
rand_device: str = "cpu",
num_inference_steps: int = 50,
eta: float = 0.0,
guidance_rescale: float = 0.0,
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(
num_inference_steps, eta=eta,
)
# 2. Three-dict input preparation
inputs_posi = {"prompt": prompt}
inputs_nega = {"negative_prompt": negative_prompt}
inputs_shared = {
"cfg_scale": cfg_scale,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
}
# 3. Unit chain execution
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
)
# 5. VAE decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"] / self.vae.scaling_factor
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class SDUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: StableDiffusionPipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class SDUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "negative_prompt"},
output_params=("prompt_embeds",),
onload_model_names=("text_encoder",)
)
def encode_prompt(
self,
pipe: StableDiffusionPipeline,
prompt: str,
device: torch.device,
) -> torch.Tensor:
text_inputs = pipe.tokenizer(
prompt,
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_embeds = pipe.text_encoder(text_input_ids)
# TextEncoder returns (last_hidden_state, hidden_states) or just last_hidden_state.
# last_hidden_state is the post-final-layer-norm output, matching diffusers encode_prompt.
if isinstance(prompt_embeds, tuple):
prompt_embeds = prompt_embeds[0]
return prompt_embeds
def process(self, pipe: StableDiffusionPipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds}
class SDUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: StableDiffusionPipeline, height, width, seed, rand_device):
noise = pipe.generate_noise(
(1, pipe.unet.in_channels, height // 8, width // 8),
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
)
return {"noise": noise}
class SDUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
return {"latents": latents, "input_latents": input_latents}
else:
return {"latents": latents}
def model_fn_stable_diffusion(
unet: UNet2DConditionModel,
latents=None,
timestep=None,
prompt_embeds=None,
cross_attention_kwargs=None,
timestep_cond=None,
added_cond_kwargs=None,
**kwargs,
):
# SD timestep is already in 0-999 range, no scaling needed
noise_pred = unet(
latents,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
timestep_cond=timestep_cond,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
return noise_pred

View File

@@ -1,331 +0,0 @@
import torch
from PIL import Image
from tqdm import tqdm
from typing import Union
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion.ddim_scheduler import DDIMScheduler
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from transformers import AutoTokenizer, CLIPTextModel
from ..models.stable_diffusion_text_encoder import SDTextEncoder
from ..models.stable_diffusion_xl_unet import SDXLUNet2DConditionModel
from ..models.stable_diffusion_xl_text_encoder import SDXLTextEncoder2
from ..models.stable_diffusion_vae import StableDiffusionVAE
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""Rescale noise_cfg based on guidance_rescale to prevent overexposure.
Based on Section 3.4 from "Common Diffusion Noise Schedules and Sample Steps are Flawed"
https://huggingface.co/papers/2305.08891
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class StableDiffusionXLPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=8, width_division_factor=8,
)
self.scheduler = DDIMScheduler()
self.text_encoder: SDTextEncoder = None
self.text_encoder_2: SDXLTextEncoder2 = None
self.unet: SDXLUNet2DConditionModel = None
self.vae: StableDiffusionVAE = None
self.tokenizer: AutoTokenizer = None
self.tokenizer_2: AutoTokenizer = None
self.in_iteration_models = ("unet",)
self.units = [
SDXLUnit_ShapeChecker(),
SDXLUnit_PromptEmbedder(),
SDXLUnit_NoiseInitializer(),
SDXLUnit_InputImageEmbedder(),
SDXLUnit_AddTimeIdsComputer(),
]
self.model_fn = model_fn_stable_diffusion_xl
self.compilable_models = ["unet"]
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = None,
tokenizer_2_config: ModelConfig = None,
vram_limit: float = None,
):
pipe = StableDiffusionXLPipeline(device=device, torch_dtype=torch_dtype)
# Override vram_config to use the specified torch_dtype for all models
for mc in model_configs:
mc._vram_config_override = {
'onload_dtype': torch_dtype,
'computation_dtype': torch_dtype,
}
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
pipe.text_encoder = model_pool.fetch_model("stable_diffusion_text_encoder")
pipe.text_encoder_2 = model_pool.fetch_model("stable_diffusion_xl_text_encoder")
pipe.unet = model_pool.fetch_model("stable_diffusion_xl_unet")
pipe.vae = model_pool.fetch_model("stable_diffusion_xl_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
if tokenizer_2_config is not None:
tokenizer_2_config.download_if_necessary()
pipe.tokenizer_2 = AutoTokenizer.from_pretrained(tokenizer_2_config.path)
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
prompt: str,
negative_prompt: str = "",
cfg_scale: float = 5.0,
height: int = 1024,
width: int = 1024,
seed: int = None,
rand_device: str = "cpu",
num_inference_steps: int = 50,
guidance_rescale: float = 0.0,
progress_bar_cmd=tqdm,
):
# 1. Scheduler
self.scheduler.set_timesteps(num_inference_steps)
# 2. Three-dict input preparation
inputs_posi = {
"prompt": prompt,
}
inputs_nega = {
"prompt": negative_prompt,
}
inputs_shared = {
"cfg_scale": cfg_scale,
"height": height, "width": width,
"seed": seed, "rand_device": rand_device,
"guidance_rescale": guidance_rescale,
"crops_coords_top_left": (0, 0),
}
# 3. Unit chain execution
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
unit, self, inputs_shared, inputs_posi, inputs_nega
)
# 4. Denoise loop
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred = self.cfg_guided_model_fn(
self.model_fn, cfg_scale,
inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
# Apply guidance_rescale
if guidance_rescale > 0.0:
# cfg_guided_model_fn already applied CFG, now apply rescale
# We need the text-only prediction for rescale
noise_pred_text = self.model_fn(
self.unet,
inputs_shared["latents"],
timestep,
inputs_posi["prompt_embeds"],
pooled_prompt_embeds=inputs_posi["pooled_prompt_embeds"],
add_time_ids=inputs_posi["add_time_ids"],
)
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
)
inputs_shared["latents"] = self.step(
self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared
)
# 6. VAE decode
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"] / self.vae.scaling_factor
image = self.vae.decode(latents)
image = self.vae_output_to_image(image)
self.load_models_to_device([])
return image
class SDXLUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("height", "width"),
)
def process(self, pipe: StableDiffusionXLPipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
return {"height": height, "width": width}
class SDXLUnit_PromptEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt": "prompt"},
input_params_nega={"prompt": "prompt"},
output_params=("prompt_embeds", "pooled_prompt_embeds"),
onload_model_names=("text_encoder", "text_encoder_2")
)
def encode_prompt(
self,
pipe: StableDiffusionXLPipeline,
prompt: str,
device: torch.device,
) -> tuple:
"""Encode prompt using both text encoders (same prompt for both).
Returns (prompt_embeds, pooled_prompt_embeds):
- prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
- pooled_prompt_embeds: encoder2 pooled output -> (B, 1280)
"""
# Text Encoder 1 (CLIP-L, 768-dim)
text_input_ids_1 = pipe.tokenizer(
prompt,
padding="max_length",
max_length=pipe.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids.to(device)
prompt_embeds_1 = pipe.text_encoder(text_input_ids_1)
if isinstance(prompt_embeds_1, tuple):
prompt_embeds_1 = prompt_embeds_1[0]
# Text Encoder 2 (CLIP-bigG, 1280-dim) — uses penultimate hidden states + pooled
text_input_ids_2 = pipe.tokenizer_2(
prompt,
padding="max_length",
max_length=pipe.tokenizer_2.model_max_length,
truncation=True,
return_tensors="pt",
).input_ids.to(device)
# SDXLTextEncoder2 forward returns (text_embeds/pooled, hidden_states_tuple)
pooled_prompt_embeds, hidden_states = pipe.text_encoder_2(text_input_ids_2, output_hidden_states=True)
# Use penultimate hidden state (same as diffusers: hidden_states[-2])
prompt_embeds_2 = hidden_states[-2]
# Concatenate both encoder outputs along feature dimension
prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
return prompt_embeds, pooled_prompt_embeds
def process(self, pipe: StableDiffusionXLPipeline, prompt):
pipe.load_models_to_device(self.onload_model_names)
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
class SDXLUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "seed", "rand_device"),
output_params=("noise",),
)
def process(self, pipe: StableDiffusionXLPipeline, height, width, seed, rand_device):
noise = pipe.generate_noise(
(1, pipe.unet.in_channels, height // 8, width // 8),
seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype
)
return {"noise": noise}
class SDXLUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",),
)
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
if input_image is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_tensor = pipe.preprocess_image(input_image)
input_latents = pipe.vae.encode(input_tensor).sample() * pipe.vae.scaling_factor
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
if pipe.scheduler.training:
return {"latents": latents, "input_latents": input_latents}
else:
return {"latents": latents}
class SDXLUnit_AddTimeIdsComputer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width"),
output_params=("add_time_ids",),
)
def _get_add_time_ids(self, pipe, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
addition_time_embed_dim = pipe.unet.add_time_proj.num_channels
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
f"but a vector of {passed_add_embed_dim} was created."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype, device=pipe.device)
return add_time_ids
def process(self, pipe: StableDiffusionXLPipeline, height, width):
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
add_time_ids = self._get_add_time_ids(
pipe, original_size, crops_coords_top_left, target_size,
dtype=pipe.torch_dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
return {"add_time_ids": add_time_ids}
def model_fn_stable_diffusion_xl(
unet: SDXLUNet2DConditionModel,
latents=None,
timestep=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
add_time_ids=None,
cross_attention_kwargs=None,
timestep_cond=None,
**kwargs,
):
"""SDXL model forward with added_cond_kwargs for micro-conditioning."""
added_cond_kwargs = {
"text_embeds": pooled_prompt_embeds,
"time_ids": add_time_ids,
}
noise_pred = unet(
latents,
timestep,
encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
timestep_cond=timestep_cond,
return_dict=False,
)[0]
return noise_pred

View File

@@ -99,6 +99,7 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
""" """
if waveform.dim() == 3: if waveform.dim() == 3:
waveform = waveform[0] waveform = waveform[0]
waveform.cpu()
if backend == "torchcodec": if backend == "torchcodec":
from torchcodec.encoders import AudioEncoder from torchcodec.encoders import AudioEncoder

View File

@@ -0,0 +1,13 @@
def AceStepConditionEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "encoder."
for key in state_dict:
if key.startswith(prefix):
new_key = key[len(prefix):]
new_state_dict[new_key] = state_dict[key]
if "null_condition_emb" in state_dict:
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
return new_state_dict

View File

@@ -0,0 +1,10 @@
def AceStepDiTModelStateDictConverter(state_dict):
new_state_dict = {}
prefix = "decoder."
for key in state_dict:
if key.startswith(prefix):
new_key = key[len(prefix):]
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,15 @@
def AceStepTextEncoderStateDictConverter(state_dict):
new_state_dict = {}
prefix = "model."
nested_prefix = "model.model."
for key in state_dict:
if key.startswith(nested_prefix):
new_key = key
elif key.startswith(prefix):
new_key = "model." + key
else:
new_key = "model." + key
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -0,0 +1,8 @@
def AceStepTokenizerStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key.startswith("tokenizer.") or key.startswith("detokenizer."):
new_state_dict[key] = state_dict[key]
return new_state_dict

View File

@@ -1,7 +0,0 @@
def SDTextEncoderStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key
new_state_dict[new_key] = state_dict[key]
return new_state_dict

View File

@@ -1,18 +0,0 @@
def SDVAEStateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if ".query." in key:
new_key = key.replace(".query.", ".to_q.")
new_state_dict[new_key] = state_dict[key]
elif ".key." in key:
new_key = key.replace(".key.", ".to_k.")
new_state_dict[new_key] = state_dict[key]
elif ".value." in key:
new_key = key.replace(".value.", ".to_v.")
new_state_dict[new_key] = state_dict[key]
elif ".proj_attn." in key:
new_key = key.replace(".proj_attn.", ".to_out.0.")
new_state_dict[new_key] = state_dict[key]
else:
new_state_dict[key] = state_dict[key]
return new_state_dict

View File

@@ -1,13 +0,0 @@
import torch
def SDXLTextEncoder2StateDictConverter(state_dict):
new_state_dict = {}
for key in state_dict:
if key == "text_projection.weight":
val = state_dict[key]
new_state_dict["model.text_projection.weight"] = val.float() if val.dtype == torch.float16 else val
elif key.startswith("text_model.") and "position_ids" not in key:
new_key = "model." + key
val = state_dict[key]
new_state_dict[new_key] = val.float() if val.dtype == torch.float16 else val
return new_state_dict

View File

@@ -0,0 +1,164 @@
# ACE-Step
ACE-Step 1.5 is an open-source music generation model based on DiT architecture, supporting text-to-music, audio cover, repainting and other functionalities, running efficiently on consumer-grade hardware.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 3GB VRAM.
```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
## Model Inference
The model is loaded via `AceStepPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `AceStepPipeline` inference include:
* `prompt`: Text description of the music.
* `cfg_scale`: Classifier-free guidance scale, defaults to 1.0.
* `lyrics`: Lyrics text.
* `task_type`: Task type,可选 values include `"text2music"` (text-to-music), `"cover"` (audio cover), `"repaint"` (repainting), defaults to `"text2music"`.
* `reference_audios`: List of reference audio tensors for timbre reference.
* `src_audio`: Source audio tensor for cover or repaint tasks.
* `denoising_strength`: Denoising strength, controlling how much the output is influenced by source audio, defaults to 1.0.
* `audio_cover_strength`: Audio cover step ratio, controlling how many steps use cover condition in cover tasks, defaults to 1.0.
* `audio_code_string`: Input audio code string for cover tasks with discrete audio codes.
* `repainting_ranges`: List of repainting time ranges (tuples of floats, in seconds) for repaint tasks.
* `repainting_strength`: Repainting intensity, controlling the degree of change in repainted areas, defaults to 1.0.
* `duration`: Audio duration in seconds, defaults to 60.
* `bpm`: Beats per minute, defaults to 100.
* `keyscale`: Musical key scale, defaults to "B minor".
* `timesignature`: Time signature, defaults to "4".
* `vocal_language`: Vocal language, defaults to "unknown".
* `seed`: Random seed.
* `rand_device`: Device for noise generation, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, defaults to 8.
* `shift`: Timestep shift parameter for the scheduler, defaults to 1.0.
## Model Training
Models in the ace_step series are trained uniformly via `examples/ace_step/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* ACE-Step Specific Parameters
* `--tokenizer_path`: Tokenizer path, in format model_id:origin_pattern.
* `--silence_latent_path`: Silence latent path, in format model_id:origin_pattern.
* `--initialize_model_on_cpu`: Whether to initialize models on CPU.
### Example Dataset
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -1,141 +0,0 @@
# Stable Diffusion XL
Stable Diffusion XL (SDXL) is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting 1024x1024 resolution high-quality text-to-image generation with a dual text encoder (CLIP-L + CLIP-bigG) architecture.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will quickly load the [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 6GB VRAM.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
## Model Inference
The model is loaded via `StableDiffusionXLPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `StableDiffusionXLPipeline` inference include:
* `prompt`: Text prompt.
* `negative_prompt`: Negative prompt, defaults to an empty string.
* `cfg_scale`: Classifier-Free Guidance scale factor, default 5.0.
* `height`: Output image height, default 1024.
* `width`: Output image width, default 1024.
* `seed`: Random seed, defaults to a random value if not set.
* `rand_device`: Noise generation device, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, default 50.
* `guidance_rescale`: Guidance rescale factor, default 0.0.
* `progress_bar_cmd`: Progress bar callback function.
> `StableDiffusionXLPipeline` requires dual tokenizer configurations (`tokenizer_config` and `tokenizer_2_config`), corresponding to the CLIP-L and CLIP-bigG text encoders.
## Model Training
Models in the stable_diffusion_xl series are trained via `examples/stable_diffusion_xl/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* Stable Diffusion XL Specific Parameters
* `--tokenizer_path`: Path to the first tokenizer.
* `--tokenizer_2_path`: Path to the second tokenizer, defaults to `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`.
Example dataset download:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-xl-base-1.0 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -1,138 +0,0 @@
# Stable Diffusion
Stable Diffusion is an open-source diffusion-based text-to-image generation model developed by Stability AI, supporting 512x512 resolution text-to-image generation.
## Installation
Before performing model inference and training, please install DiffSynth-Studio first.
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
## Quick Start
Running the following code will quickly load the [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 2GB VRAM.
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")
```
## Model Overview
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
## Model Inference
The model is loaded via `StableDiffusionPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
The input parameters for `StableDiffusionPipeline` inference include:
* `prompt`: Text prompt.
* `negative_prompt`: Negative prompt, defaults to an empty string.
* `cfg_scale`: Classifier-Free Guidance scale factor, default 7.5.
* `height`: Output image height, default 512.
* `width`: Output image width, default 512.
* `seed`: Random seed, defaults to a random value if not set.
* `rand_device`: Noise generation device, defaults to "cpu".
* `num_inference_steps`: Number of inference steps, default 50.
* `eta`: DDIM scheduler eta parameter, default 0.0.
* `guidance_rescale`: Guidance rescale factor, default 0.0.
* `progress_bar_cmd`: Progress bar callback function.
## Model Training
Models in the stable_diffusion series are trained via `examples/stable_diffusion/model_training/train.py`. The script parameters include:
* General Training Parameters
* Dataset Configuration
* `--dataset_base_path`: Root directory of the dataset.
* `--dataset_metadata_path`: Path to the dataset metadata file.
* `--dataset_repeat`: Number of dataset repeats per epoch.
* `--dataset_num_workers`: Number of processes per DataLoader.
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
* Model Loading Configuration
* `--model_paths`: Paths to load models from, in JSON format.
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
* Basic Training Configuration
* `--learning_rate`: Learning rate.
* `--num_epochs`: Number of epochs.
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
* `--weight_decay`: Weight decay magnitude.
* `--task`: Training task, defaults to `sft`.
* Output Configuration
* `--output_path`: Path to save the model.
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
* `--save_steps`: Interval in training steps to save the model.
* LoRA Configuration
* `--lora_base_model`: Which model to add LoRA to.
* `--lora_target_modules`: Which layers to add LoRA to.
* `--lora_rank`: Rank of LoRA.
* `--lora_checkpoint`: Path to LoRA checkpoint.
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
* Gradient Configuration
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
* Resolution Configuration
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
* `--num_frames`: Number of frames for video (video generation models only).
* Stable Diffusion Specific Parameters
* `--tokenizer_path`: Tokenizer path, defaults to `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`.
Example dataset download:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-v1-5 training scripts](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).

View File

@@ -32,8 +32,7 @@ Welcome to DiffSynth-Studio's Documentation
Model_Details/LTX-2 Model_Details/LTX-2
Model_Details/ERNIE-Image Model_Details/ERNIE-Image
Model_Details/JoyAI-Image Model_Details/JoyAI-Image
Model_Details/Stable-Diffusion Model_Details/ACE-Step
Model_Details/Stable-Diffusion-XL
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2

View File

@@ -0,0 +1,164 @@
# ACE-Step
ACE-Step 1.5 是一个开源音乐生成模型,基于 DiT 架构,支持文生音乐、音频翻唱、局部重绘等多种功能,可在消费级硬件上高效运行。
## 安装
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
## 快速开始
运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
```python
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
```
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
## 模型推理
模型通过 `AceStepPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
`AceStepPipeline` 推理的输入参数包括:
* `prompt`: 音乐文本描述。
* `cfg_scale`: 分类器无条件引导比例,默认为 1.0。
* `lyrics`: 歌词文本。
* `task_type`: 任务类型,可选值包括 `"text2music"`(文生音乐)、`"cover"`(音频翻唱)、`"repaint"`(局部重绘),默认为 `"text2music"`
* `reference_audios`: 参考音频列表Tensor 列表),用于提供音色参考。
* `src_audio`: 源音频Tensor用于 cover 或 repaint 任务。
* `denoising_strength`: 降噪强度,控制输出受源音频的影响程度,默认为 1.0。
* `audio_cover_strength`: 音频翻唱步数比例,控制 cover 任务中前多少步使用翻唱条件,默认为 1.0。
* `audio_code_string`: 输入音频码字符串,用于 cover 任务中直接传入离散音频码。
* `repainting_ranges`: 重绘时间区间(浮点元组列表,单位为秒),用于 repaint 任务。
* `repainting_strength`: 重绘强度,控制重绘区域的变化程度,默认为 1.0。
* `duration`: 音频时长(秒),默认为 60。
* `bpm`: 每分钟节拍数,默认为 100。
* `keyscale`: 音阶调式,默认为 "B minor"。
* `timesignature`: 拍号,默认为 "4"。
* `vocal_language`: 演唱语言,默认为 "unknown"。
* `seed`: 随机种子。
* `rand_device`: 噪声生成设备,默认为 "cpu"。
* `num_inference_steps`: 推理步数,默认为 8。
* `shift`: 调度器时间偏移参数,默认为 1.0。
## 模型训练
ace_step 系列模型统一通过 `examples/ace_step/model_training/train.py` 进行训练,脚本的参数包括:
* 通用训练参数
* 数据集基础配置
* `--dataset_base_path`: 数据集的根目录。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
* 模型加载配置
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
* 训练基础配置
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch
* `--trainable_models`: 可训练的模型,例如 `dit``vae``text_encoder`
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* `--weight_decay`: 权重衰减大小。
* `--task`: 训练任务,默认为 `sft`
* 输出配置
* `--output_path`: 模型保存路径。
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
* `--save_steps`: 保存模型的训练步数间隔。
* LoRA 配置
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪些层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`
* 梯度配置
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 分辨率配置
* `--height`: 图像/视频的高度。留空启用动态分辨率。
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
* `--num_frames`: 视频的帧数(仅视频生成模型)。
* ACE-Step 专有参数
* `--tokenizer_path`: Tokenizer 路径,格式为 model_id:origin_pattern。
* `--silence_latent_path`: 静音隐变量路径,格式为 model_id:origin_pattern。
* `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型。
### 样例数据集
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
```
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。

View File

@@ -1,141 +0,0 @@
# Stable Diffusion XL
Stable Diffusion XL (SDXL) 是由 Stability AI 开发的开源扩散式文本到图像生成模型,支持 1024x1024 分辨率的高质量文本到图像生成采用双文本编码器CLIP-L + CLIP-bigG架构。
## 安装
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
## 快速开始
运行以下代码可以快速加载 [stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 6GB 显存即可运行。
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionXLPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars",
negative_prompt="",
cfg_scale=5.0,
height=1024,
width=1024,
seed=42,
num_inference_steps=50,
)
image.save("image.jpg")
```
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[stabilityai/stable-diffusion-xl-base-1.0](https://www.modelscope.cn/models/stabilityai/stable-diffusion-xl-base-1.0)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_inference_low_vram/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/validate_lora/stable-diffusion-xl-base-1.0.py)|
## 模型推理
模型通过 `StableDiffusionXLPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
`StableDiffusionXLPipeline` 的推理输入参数包括:
* `prompt`: 文本提示词。
* `negative_prompt`: 负面提示词,默认为空字符串。
* `cfg_scale`: Classifier-Free Guidance 缩放系数,默认 5.0。
* `height`: 输出图像高度,默认 1024。
* `width`: 输出图像宽度,默认 1024。
* `seed`: 随机种子,默认不设置时使用随机种子。
* `rand_device`: 噪声生成设备,默认 "cpu"。
* `num_inference_steps`: 推理步数,默认 50。
* `guidance_rescale`: Guidance rescale 系数,默认 0.0。
* `progress_bar_cmd`: 进度条回调函数。
> `StableDiffusionXLPipeline` 需要双 tokenizer 配置(`tokenizer_config` 和 `tokenizer_2_config`),分别对应 CLIP-L 和 CLIP-bigG 文本编码器。
## 模型训练
stable_diffusion_xl 系列模型通过 `examples/stable_diffusion_xl/model_training/train.py` 进行训练,脚本的参数包括:
* 通用训练参数
* 数据集基础配置
* `--dataset_base_path`: 数据集的根目录。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
* 模型加载配置
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
* 训练基础配置
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch
* `--trainable_models`: 可训练的模型,例如 `dit``vae``text_encoder`
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* `--weight_decay`: 权重衰减大小。
* `--task`: 训练任务,默认为 `sft`
* 输出配置
* `--output_path`: 模型保存路径。
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
* `--save_steps`: 保存模型的训练步数间隔。
* LoRA 配置
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪些层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`
* 梯度配置
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 分辨率配置
* `--height`: 图像/视频的高度。留空启用动态分辨率。
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
* `--num_frames`: 视频的帧数(仅视频生成模型)。
* Stable Diffusion XL 专有参数
* `--tokenizer_path`: 第一个 Tokenizer 路径。
* `--tokenizer_2_path`: 第二个 Tokenizer 路径,默认为 `stabilityai/stable-diffusion-xl-base-1.0:tokenizer_2/`
样例数据集下载:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-xl-base-1.0 训练脚本](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion_xl/model_training/lora/stable-diffusion-xl-base-1.0.sh)
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。

View File

@@ -1,138 +0,0 @@
# Stable Diffusion
Stable Diffusion 是由 Stability AI 开发的开源扩散式文本到图像生成模型,支持 512x512 分辨率的文本到图像生成。
## 安装
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
```shell
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
## 快速开始
运行以下代码可以快速加载 [AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 2GB 显存即可运行。
```python
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")
```
## 模型总览
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[AI-ModelScope/stable-diffusion-v1-5](https://www.modelscope.cn/models/AI-ModelScope/stable-diffusion-v1-5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_inference_low_vram/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/validate_lora/stable-diffusion-v1-5.py)|
## 模型推理
模型通过 `StableDiffusionPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
`StableDiffusionPipeline` 的推理输入参数包括:
* `prompt`: 文本提示词。
* `negative_prompt`: 负面提示词,默认为空字符串。
* `cfg_scale`: Classifier-Free Guidance 缩放系数,默认 7.5。
* `height`: 输出图像高度,默认 512。
* `width`: 输出图像宽度,默认 512。
* `seed`: 随机种子,默认不设置时使用随机种子。
* `rand_device`: 噪声生成设备,默认 "cpu"。
* `num_inference_steps`: 推理步数,默认 50。
* `eta`: DDIM 调度器的 eta 参数,默认 0.0。
* `guidance_rescale`: Guidance rescale 系数,默认 0.0。
* `progress_bar_cmd`: 进度条回调函数。
## 模型训练
stable_diffusion 系列模型通过 `examples/stable_diffusion/model_training/train.py` 进行训练,脚本的参数包括:
* 通用训练参数
* 数据集基础配置
* `--dataset_base_path`: 数据集的根目录。
* `--dataset_metadata_path`: 数据集的元数据文件路径。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
* 模型加载配置
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
* 训练基础配置
* `--learning_rate`: 学习率。
* `--num_epochs`: 轮数Epoch
* `--trainable_models`: 可训练的模型,例如 `dit``vae``text_encoder`
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
* `--weight_decay`: 权重衰减大小。
* `--task`: 训练任务,默认为 `sft`
* 输出配置
* `--output_path`: 模型保存路径。
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
* `--save_steps`: 保存模型的训练步数间隔。
* LoRA 配置
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪些层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`
* 梯度配置
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
* `--gradient_accumulation_steps`: 梯度累积步数。
* 分辨率配置
* `--height`: 图像/视频的高度。留空启用动态分辨率。
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
* `--num_frames`: 视频的帧数(仅视频生成模型)。
* Stable Diffusion 专有参数
* `--tokenizer_path`: Tokenizer 路径,默认为 `AI-ModelScope/stable-diffusion-v1-5:tokenizer/`
样例数据集下载:
```shell
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/*" --local_dir ./data/diffsynth_example_dataset
```
[stable-diffusion-v1-5 训练脚本](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/stable_diffusion/model_training/lora/stable-diffusion-v1-5.sh)
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。

View File

@@ -32,8 +32,7 @@
Model_Details/LTX-2 Model_Details/LTX-2
Model_Details/ERNIE-Image Model_Details/ERNIE-Image
Model_Details/JoyAI-Image Model_Details/JoyAI-Image
Model_Details/Stable-Diffusion Model_Details/ACE-Step
Model_Details/Stable-Diffusion-XL
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2

View File

@@ -0,0 +1,53 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
# input audio codes as reference
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
audio_code_string=audio_code_string,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes.wav")

View File

@@ -0,0 +1,45 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
src_audio=src_audio,
audio_cover_strength=0.5,
denoising_strength=0.9,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")

View File

@@ -0,0 +1,47 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="repaint",
src_audio=src_audio,
repainting_ranges=[(-10, 30), (150, 200)],
repainting_strength=1.0,
duration=210,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")

View File

@@ -0,0 +1,31 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base.wav")

View File

@@ -0,0 +1,38 @@
"""
Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example.
SFT variant is fine-tuned for specific music styles.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft.wav")

View File

@@ -0,0 +1,36 @@
"""
Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
Continuous variant: handles shift range internally, no shift parameter needed.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous.wav")

View File

@@ -0,0 +1,36 @@
"""
Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=1: default value, no need to pass.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1.wav")

View File

@@ -0,0 +1,37 @@
"""
Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
shift=3,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3.wav")

View File

@@ -0,0 +1,38 @@
"""
Ace-Step 1.5 XL Base (32 layers, hidden_size=2560) — Text-to-Music inference example.
XL variant with larger capacity for higher quality generation.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base.wav")

View File

@@ -0,0 +1,37 @@
"""
Ace-Step 1.5 XL SFT (32 layers, supervised fine-tuned) — Text-to-Music inference example.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft.wav")

View File

@@ -0,0 +1,36 @@
"""
Ace-Step 1.5 XL Turbo (32 layers, fast generation) — Text-to-Music inference example.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo.wav")

View File

@@ -0,0 +1,73 @@
"""
Ace-Step 1.5 (main model, turbo) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: uses num_inference_steps=8, cfg_scale=1.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
# input audio codes as reference
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
)
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
audio_code_string = f.read().strip()
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
audio_code_string=audio_code_string,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes-low-vram.wav")

View File

@@ -0,0 +1,57 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="cover",
src_audio=src_audio,
audio_cover_strength=0.5,
denoising_strength=0.9,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")

View File

@@ -0,0 +1,59 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio, read_audio
from modelscope import dataset_snapshot_download
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
dataset_snapshot_download(
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
)
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
audio = pipe(
prompt=prompt,
lyrics=lyrics,
task_type="repaint",
src_audio=src_audio,
repainting_ranges=[(-10, 30), (150, 200)],
repainting_strength=1.0,
duration=210,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")

View File

@@ -0,0 +1,49 @@
"""
Ace-Step 1.5 Base — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-low-vram.wav")

View File

@@ -0,0 +1,51 @@
"""
Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
SFT variant is fine-tuned for specific music styles.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft-low-vram.wav")

View File

@@ -0,0 +1,49 @@
"""
Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
Continuous variant: handles shift range internally, no shift parameter needed.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous-low-vram.wav")

View File

@@ -0,0 +1,49 @@
"""
Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=1: default value, no need to pass.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1-low-vram.wav")

View File

@@ -0,0 +1,50 @@
"""
Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
shift=3: explicitly passed for this variant.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
shift=3,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3-low-vram.wav")

View File

@@ -0,0 +1,51 @@
"""
Ace-Step 1.5 XL Base — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
torch.cuda.reset_peak_memory_stats("cuda")
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base-low-vram.wav")

View File

@@ -0,0 +1,50 @@
"""
Ace-Step 1.5 XL SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft-low-vram.wav")

View File

@@ -0,0 +1,48 @@
"""
Ace-Step 1.5 XL Turbo — Text-to-Music inference example (Low VRAM).
Low VRAM version: models are offloaded to CPU and loaded on-demand.
Turbo model: no num_inference_steps or cfg_scale (use defaults).
"""
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
vram_config = {
"offload_dtype": torch.bfloat16,
"offload_device": "cpu",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo-low-vram.wav")

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
--model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/Ace-Step1.5_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-base_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-sft_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-continuous_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift1_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift3_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-base_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-sft_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,18 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "dit" \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-turbo_full" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
--model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/Ace-Step1.5_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-base_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-sft_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-continuous_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift1_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-turbo-shift3_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-base_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-sft_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -0,0 +1,20 @@
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/ace_step/model_training/train.py \
--learning_rate 1e-4 \
--num_epochs 20 \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters \
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
--lora_base_model "dit" \
--remove_prefix_in_ckpt "pipe.dit." \
--dataset_repeat 50 \
--output_path "./models/train/acestep-v15-xl-turbo_lora" \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
--data_file_keys "audio"

View File

@@ -1,15 +1,20 @@
import torch, os, argparse, accelerate import os
import torch
import math
import argparse
import accelerate
from diffsynth.core import UnifiedDataset from diffsynth.core import UnifiedDataset
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.diffusion import * from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
class StableDiffusionTrainingModule(DiffusionTrainingModule): class AceStepTrainingModule(DiffusionTrainingModule):
def __init__( def __init__(
self, self,
model_paths=None, model_id_with_origin_paths=None, model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None, tokenizer_path=None, silence_latent_path=None,
trainable_models=None, trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None, preset_lora_path=None, preset_lora_model=None,
@@ -22,13 +27,15 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
task="sft", task="sft",
): ):
super().__init__() super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device) model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/")) text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config) silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
self.pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16, device=device, model_configs=model_configs,
text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config,
)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model) self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# Training mode
self.switch_pipe_to_training_mode( self.switch_pipe_to_training_mode(
self.pipe, trainable_models, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
@@ -36,7 +43,6 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
task=task, task=task,
) )
# Other configs
self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
@@ -44,24 +50,23 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
self.task = task self.task = task
self.task_to_loss = { self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args, "sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi), "sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
} }
def get_pipeline_inputs(self, data): def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]} inputs_posi = {"prompt": data["prompt"], "positive": True}
inputs_nega = {"negative_prompt": ""} inputs_nega = {"positive": False}
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
inputs_shared = { inputs_shared = {
# Assume you are using this pipeline for inference, "input_audio": data["audio"],
# please fill in the input parameters. "lyrics": data["lyrics"],
"input_image": data["image"], "task_type": "text2music",
"height": data["image"].size[1], "duration": duration,
"width": data["image"].size[0], "bpm": data.get("bpm", 100),
# Please do not modify the following parameters "keyscale": data.get("keyscale", "C major"),
# unless you clearly know what this will cause. "timesignature": data.get("timesignature", "4"),
"vocal_language": data.get("vocal_language", "unknown"),
"cfg_scale": 1, "cfg_scale": 1,
"rand_device": self.pipe.device, "rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing, "use_gradient_checkpointing": self.use_gradient_checkpointing,
@@ -79,16 +84,17 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
return loss return loss
def parser(): def ace_step_parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="ACE-Step training.")
parser = add_general_config(parser) parser = add_general_config(parser)
parser = add_image_size_config(parser) parser.add_argument("--tokenizer_path", type=str, default=None, help="Tokenizer path in format model_id:origin_pattern.")
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.") parser.add_argument("--silence_latent_path", type=str, default=None, help="Silence latent path in format model_id:origin_pattern.")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser return parser
if __name__ == "__main__": if __name__ == "__main__":
parser = parser() parser = ace_step_parser()
args = parser.parse_args() args = parser.parse_args()
accelerator = accelerate.Accelerator( accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -99,19 +105,18 @@ if __name__ == "__main__":
metadata_path=args.dataset_metadata_path, metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat, repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","), data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_image_operator( main_data_operator=None,
base_path=args.dataset_base_path, special_operator_map={
max_pixels=args.max_pixels, "audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(
height=args.height, target_sample_rate=48000,
width=args.width, ),
height_division_factor=32, },
width_division_factor=32,
) )
) model = AceStepTrainingModule(
model = StableDiffusionTrainingModule(
model_paths=args.model_paths, model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths, model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path, tokenizer_path=args.tokenizer_path,
silence_latent_path=args.silence_latent_path,
trainable_models=args.trainable_models, trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model, lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules, lora_target_modules=args.lora_target_modules,
@@ -125,7 +130,7 @@ if __name__ == "__main__":
fp8_models=args.fp8_models, fp8_models=args.fp8_models,
offload_models=args.offload_models, offload_models=args.offload_models,
task=args.task, task=args.task,
device=accelerator.device, device="cpu" if args.initialize_model_on_cpu else accelerator.device,
) )
model_logger = ModelLogger( model_logger = ModelLogger(
args.output_path, args.output_path,
@@ -133,10 +138,7 @@ if __name__ == "__main__":
) )
launcher_map = { launcher_map = {
"sft:data_process": launch_data_process_task, "sft:data_process": launch_data_process_task,
"direct_distill:data_process": launch_data_process_task,
"sft": launch_training_task, "sft": launch_training_task,
"sft:train": launch_training_task, "sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
} }
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/Ace-Step1.5_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-base_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=42,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-sft_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-turbo-continuous_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-turbo-shift1_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-turbo-shift3_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-xl-base_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-xl-sft_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_full.wav")

View File

@@ -0,0 +1,35 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
from diffsynth import load_state_dict
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
state_dict = load_state_dict("models/train/acestep-v15-xl-turbo_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=8,
cfg_scale=1.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_full.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/Ace-Step1.5_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-base_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-sft_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-continuous_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift1_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift3_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-base_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-sft_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_lora.wav")

View File

@@ -0,0 +1,33 @@
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
from diffsynth.utils.data.audio import save_audio
import torch
pipe = AceStepPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
)
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-turbo_lora/epoch-9.safetensors", alpha=1)
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
audio = pipe(
prompt=prompt,
lyrics=lyrics,
duration=160,
bpm=100,
keyscale="B minor",
timesignature="4",
vocal_language="zh",
seed=1,
num_inference_steps=30,
cfg_scale=4.0,
)
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_lora.wav")

View File

@@ -1,25 +0,0 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")

View File

@@ -1,36 +0,0 @@
import torch
from diffsynth.core import ModelConfig
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline
vram_config = {
"offload_dtype": torch.float32,
"offload_device": "cpu",
"onload_dtype": torch.float32,
"onload_device": "cpu",
"preparing_dtype": torch.float32,
"preparing_device": "cuda",
"computation_dtype": torch.float32,
"computation_device": "cuda",
}
pipe = StableDiffusionPipeline.from_pretrained(
torch_dtype=torch.float32,
model_configs=[
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors", **vram_config),
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
],
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
image = pipe(
prompt="a photo of an astronaut riding a horse on mars, high quality, detailed",
negative_prompt="blurry, low quality, deformed",
cfg_scale=7.5,
height=512,
width=512,
seed=42,
rand_device="cuda",
num_inference_steps=50,
)
image.save("image.jpg")

View File

@@ -1,15 +0,0 @@
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/stable_diffusion/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
--height 512 \
--width 512 \
--dataset_repeat 50 \
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--trainable_models "unet" \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-v1-5_full" \
--use_gradient_checkpointing

View File

@@ -1,17 +0,0 @@
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch examples/stable_diffusion/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
--dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
--height 512 \
--width 512 \
--dataset_repeat 50 \
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.unet." \
--output_path "./models/train/stable-diffusion-v1-5_lora" \
--lora_base_model "unet" \
--lora_target_modules "" \
--lora_rank 32 \
--use_gradient_checkpointing

Some files were not shown because too many files have changed in this diff Show More