mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
Compare commits
1 Commits
ace-step
...
version-2.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3ca130bbd |
229
README.md
229
README.md
@@ -7,7 +7,6 @@
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
[](https://discord.gg/Mm9suEeUDc)
|
||||
|
||||
[切换到中文版](./README_zh.md)
|
||||
|
||||
@@ -33,11 +32,6 @@ We believe that a well-developed open-source code framework can lower the thresh
|
||||
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
|
||||
|
||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher) and [mi804](https://github.com/mi804). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||
|
||||
- **April 23, 2026** ACE-Step open-sourced, welcome a new member to the audio model family! Support includes text-to-music generation, low VRAM inference, and LoRA training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/ACE-Step.md) and [example code](/examples/ace_step/).
|
||||
|
||||
- **April 14, 2026** JoyAI-Image open-sourced, welcome a new member to the image editing model family! Support includes instruction-guided image editing, low VRAM inference, and training capabilities. For details, please refer to the [documentation](/docs/en/Model_Details/JoyAI-Image.md) and [example code](/examples/joyai_image/).
|
||||
|
||||
- **March 19, 2026**: Added support for [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) and [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/Wan.md) and [example code](/examples/mova/) are now available.
|
||||
|
||||
- **March 12, 2026**: We have added support for the [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) audio-video generation model. The features includes text-to-audio/video, image-to-audio/video, IC-LoRA control, audio-to-video, and audio-video inpainting. We have supported the complete inference and training functionalities. For details, please refer to the [documentation](/docs/en/Model_Details/LTX-2.md) and [code](/examples/ltx2/).
|
||||
@@ -602,143 +596,6 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
|
||||
|
||||
</details>
|
||||
|
||||
#### ERNIE-Image: [/docs/en/Model_Details/ERNIE-Image.md](/docs/en/Model_Details/ERNIE-Image.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
Running the following code will quickly load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
image.save("output.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Examples</summary>
|
||||
|
||||
Example code for ERNIE-Image is available at: [/examples/ernie_image/](/examples/ernie_image/)
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
||||
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
||||
|
||||
</details>
|
||||
|
||||
#### JoyAI-Image: [/docs/en/Model_Details/JoyAI-Image.md](/docs/en/Model_Details/JoyAI-Image.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
Running the following code will quickly load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 4GB VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
# Download dataset
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||
)
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = JoyAIImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
# Use first sample from dataset
|
||||
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||
prompt = "将裙子改为粉色"
|
||||
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
edit_image=edit_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=0,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=5.0,
|
||||
)
|
||||
|
||||
output.save("output_joyai_edit_low_vram.png")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Examples</summary>
|
||||
|
||||
Example code for JoyAI-Image is available at: [/examples/joyai_image/](/examples/joyai_image/)
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### Video Synthesis
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
@@ -1018,86 +875,6 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|
||||
|
||||
</details>
|
||||
|
||||
### Audio Synthesis
|
||||
|
||||
#### ACE-Step: [/docs/en/Model_Details/ACE-Step.md](/docs/en/Model_Details/ACE-Step.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
Running the following code will quickly load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model and perform inference. VRAM management is enabled, and the framework will automatically control the loading of model parameters based on available VRAM. The model can run with a minimum of 3GB VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Examples</summary>
|
||||
|
||||
Example code for ACE-Step is available at: [/examples/ace_step/](/examples/ace_step/)
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|
||||
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|
||||
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|
||||
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
|
||||
|
||||
</details>
|
||||
|
||||
## Innovative Achievements
|
||||
|
||||
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
|
||||
@@ -1252,9 +1029,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
## Contact Us
|
||||
|
||||
|Discord:https://discord.gg/Mm9suEeUDc|
|
||||
|-|
|
||||
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|
|
||||
|
||||
228
README_zh.md
228
README_zh.md
@@ -7,7 +7,6 @@
|
||||
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
||||
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
||||
[](https://discord.gg/Mm9suEeUDc)
|
||||
|
||||
[Switch to English](./README.md)
|
||||
|
||||
@@ -34,10 +33,6 @@ DiffSynth 目前包括两个开源项目:
|
||||
|
||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 和 [mi804](https://github.com/mi804) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||
|
||||
- **2026年4月23日** ACE-Step 开源,欢迎加入音频生成模型家族!支持文生音乐推理、低显存推理和 LoRA 训练能力。详情请参考[文档](/docs/zh/Model_Details/ACE-Step.md)和[示例代码](/examples/ace_step/)。
|
||||
|
||||
- **2026年4月14日** JoyAI-Image 开源,欢迎加入图像编辑模型家族!支持指令引导的图像编辑推理、低显存推理和训练能力。详情请参考[文档](/docs/zh/Model_Details/JoyAI-Image.md)和[示例代码](/examples/joyai_image/)。
|
||||
|
||||
- **2026年3月19日** 新增对 [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) 和 [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/Wan.md)和[示例代码](/examples/mova/)现已可用。
|
||||
|
||||
- **2026年3月12日** 我们新增了 [LTX-2.3](https://modelscope.cn/models/Lightricks/LTX-2.3) 音视频生成模型的支持,模型支持的功能包括文生音视频、图生音视频、IC-LoRA控制、音频生视频、音视频局部Inpainting,框架支持完整的推理和训练功能。详细信息请参考 [文档](/docs/zh/Model_Details/LTX-2.md) 和 [示例代码](/examples/ltx2/)。
|
||||
@@ -602,143 +597,6 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|
||||
|
||||
</details>
|
||||
|
||||
#### ERNIE-Image: [/docs/zh/Model_Details/ERNIE-Image.md](/docs/zh/Model_Details/ERNIE-Image.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
image.save("output.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>示例代码</summary>
|
||||
|
||||
ERNIE-Image 的示例代码位于:[/examples/ernie_image/](/examples/ernie_image/)
|
||||
|
||||
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
||||
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
||||
|
||||
</details>
|
||||
|
||||
#### JoyAI-Image: [/docs/zh/Model_Details/JoyAI-Image.md](/docs/zh/Model_Details/JoyAI-Image.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
# Download dataset
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||
)
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = JoyAIImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
# Use first sample from dataset
|
||||
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||
prompt = "将裙子改为粉色"
|
||||
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
edit_image=edit_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=0,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=5.0,
|
||||
)
|
||||
|
||||
output.save("output_joyai_edit_low_vram.png")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>示例代码</summary>
|
||||
|
||||
JoyAI-Image 的示例代码位于:[/examples/joyai_image/](/examples/joyai_image/)
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||
|
||||
</details>
|
||||
|
||||
### 视频生成模型
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
@@ -1018,86 +876,6 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|
||||
|
||||
</details>
|
||||
|
||||
### 音频生成模型
|
||||
|
||||
#### ACE-Step: [/docs/zh/Model_Details/ACE-Step.md](/docs/zh/Model_Details/ACE-Step.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>示例代码</summary>
|
||||
|
||||
ACE-Step 的示例代码位于:[/examples/ace_step/](/examples/ace_step/)
|
||||
|
||||
| 模型 ID | 推理 | 低显存推理 | 全量训练 | 全量训练后验证 | LoRA 训练 | LoRA 训练后验证 |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|
||||
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|
||||
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|
||||
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
|
||||
|
||||
</details>
|
||||
|
||||
## 创新成果
|
||||
|
||||
DiffSynth-Studio 不仅仅是一个工程化的模型框架,更是创新成果的孵化器。
|
||||
@@ -1254,9 +1032,3 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
## 联系我们
|
||||
|
||||
|Discord:https://discord.gg/Mm9suEeUDc|
|
||||
|-|
|
||||
|<img width="160" height="160" alt="Image" src="https://github.com/user-attachments/assets/29bdc97b-e35d-4fea-88d6-32e35182e458" />|
|
||||
|
||||
@@ -541,22 +541,6 @@ flux2_series = [
|
||||
},
|
||||
]
|
||||
|
||||
ernie_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
||||
"model_hash": "584c13713849f1af4e03d5f1858b8b7b",
|
||||
"model_name": "ernie_image_dit",
|
||||
"model_class": "diffsynth.models.ernie_image_dit.ErnieImageDiT",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors")
|
||||
"model_hash": "404ed9f40796a38dd34c1620f1920207",
|
||||
"model_name": "ernie_image_text_encoder",
|
||||
"model_class": "diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ernie_image_text_encoder.ErnieImageTextEncoderStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
z_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
|
||||
@@ -900,102 +884,4 @@ mova_series = [
|
||||
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
|
||||
},
|
||||
]
|
||||
joyai_image_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth")
|
||||
"model_hash": "56592ddfd7d0249d3aa527d24161a863",
|
||||
"model_name": "joyai_image_dit",
|
||||
"model_class": "diffsynth.models.joyai_image_dit.JoyAIImageDiT",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model-*.safetensors")
|
||||
"model_hash": "2d11bf14bba8b4e87477c8199a895403",
|
||||
"model_name": "joyai_image_text_encoder",
|
||||
"model_class": "diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.joyai_image_text_encoder.JoyAIImageTextEncoderStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
ace_step_series = [
|
||||
# === Standard DiT variants (24 layers, hidden_size=2048) ===
|
||||
# Covers: turbo, turbo-shift1, turbo-shift3, turbo-continuous, base, sft
|
||||
# All share identical state_dict structure → same hash
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
|
||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||
"model_name": "ace_step_dit",
|
||||
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
|
||||
},
|
||||
# === XL DiT variants (32 layers, hidden_size=2560) ===
|
||||
# Covers: xl-base, xl-sft, xl-turbo
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
|
||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||
"model_name": "ace_step_dit",
|
||||
"model_class": "diffsynth.models.ace_step_dit.AceStepDiTModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTModelStateDictConverter",
|
||||
"extra_kwargs": {
|
||||
"hidden_size": 2560,
|
||||
"intermediate_size": 9728,
|
||||
"num_hidden_layers": 32,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"head_dim": 128,
|
||||
"encoder_hidden_size": 2048,
|
||||
"layer_types": ["sliding_attention", "full_attention"] * 16,
|
||||
},
|
||||
},
|
||||
# === Conditioner (shared by all DiT variants, same architecture) ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
|
||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||
"model_name": "ace_step_conditioner",
|
||||
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
|
||||
},
|
||||
# === XL Conditioner (same architecture, but checkpoint includes XL decoder → different file hash) ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
|
||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||
"model_name": "ace_step_conditioner",
|
||||
"model_class": "diffsynth.models.ace_step_conditioner.AceStepConditionEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_conditioner.AceStepConditionEncoderStateDictConverter",
|
||||
},
|
||||
# === Qwen3-Embedding (text encoder) ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
|
||||
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
|
||||
"model_name": "ace_step_text_encoder",
|
||||
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
|
||||
},
|
||||
# === VAE (AutoencoderOobleck CNN) ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
||||
"model_hash": "51420834e54474986a7f4be0e4d6f687",
|
||||
"model_name": "ace_step_vae",
|
||||
"model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
|
||||
},
|
||||
# === Tokenizer (VAE latent discretization: tokenizer + detokenizer) ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
|
||||
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
|
||||
"model_name": "ace_step_tokenizer",
|
||||
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
|
||||
},
|
||||
# === XL Tokenizer (XL models share same tokenizer architecture) ===
|
||||
{
|
||||
# Example: ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors")
|
||||
"model_hash": "3a28a410c2246f125153ef792d8bc828",
|
||||
"model_name": "ace_step_tokenizer",
|
||||
"model_class": "diffsynth.models.ace_step_tokenizer.AceStepTokenizer",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_tokenizer.AceStepTokenizerStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = (
|
||||
qwen_image_series + wan_series + flux_series + flux2_series + ernie_image_series
|
||||
+ z_image_series + ltx2_series + anima_series + mova_series + joyai_image_series + ace_step_series
|
||||
)
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
|
||||
|
||||
@@ -267,72 +267,6 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ernie_image_dit.ErnieImageDiT": {
|
||||
"diffsynth.models.ernie_image_dit.ErnieImageRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ernie_image_text_encoder.ErnieImageTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.ministral3.modeling_ministral3.Ministral3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.joyai_image_dit.Transformer3DModel": {
|
||||
"diffsynth.models.joyai_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.joyai_image_dit.ModulateWan": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.joyai_image_text_encoder.JoyAIImageTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
# ACE-Step module maps
|
||||
"diffsynth.models.ace_step_dit.AceStepDiTModel": {
|
||||
"diffsynth.models.ace_step_dit.AceStepDiTLayer": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_conditioner.AceStepConditionEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_text_encoder.AceStepTextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_vae.AceStepVAE": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.ace_step_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"vector_quantize_pytorch.ResidualFSQ": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
def QwenImageTextEncoder_Module_Map_Updater():
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import math, warnings
|
||||
import math
|
||||
import torch, torchvision, imageio, os
|
||||
import imageio.v3 as iio
|
||||
from PIL import Image
|
||||
import torchaudio
|
||||
from diffsynth.utils.data.audio import read_audio
|
||||
|
||||
|
||||
class DataProcessingPipeline:
|
||||
@@ -261,43 +260,15 @@ class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin):
|
||||
FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate)
|
||||
|
||||
def __call__(self, data: str):
|
||||
try:
|
||||
reader = self.get_reader(data)
|
||||
num_frames = self.get_num_frames(reader)
|
||||
duration = num_frames / self.frame_rate
|
||||
waveform, sample_rate = torchaudio.load(data)
|
||||
target_samples = int(duration * sample_rate)
|
||||
current_samples = waveform.shape[-1]
|
||||
if current_samples > target_samples:
|
||||
waveform = waveform[..., :target_samples]
|
||||
elif current_samples < target_samples:
|
||||
padding = target_samples - current_samples
|
||||
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
||||
return waveform, sample_rate
|
||||
except:
|
||||
warnings.warn(f"Cannot load audio in {data}. The audio will be `None`.")
|
||||
return None
|
||||
|
||||
|
||||
class LoadPureAudioWithTorchaudio(DataProcessingOperator):
|
||||
|
||||
def __init__(self, target_sample_rate=None, target_duration=None):
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.target_duration = target_duration
|
||||
self.resample = True if target_sample_rate is not None else False
|
||||
|
||||
def __call__(self, data: str):
|
||||
try:
|
||||
waveform, sample_rate = read_audio(data, resample=self.resample, resample_rate=self.target_sample_rate)
|
||||
if self.target_duration is not None:
|
||||
target_samples = int(self.target_duration * sample_rate)
|
||||
current_samples = waveform.shape[-1]
|
||||
if current_samples > target_samples:
|
||||
waveform = waveform[..., :target_samples]
|
||||
elif current_samples < target_samples:
|
||||
padding = target_samples - current_samples
|
||||
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
||||
return waveform, sample_rate
|
||||
except Exception as e:
|
||||
warnings.warn(f"Cannot load audio in '{data}' due to '{e}'. The audio will be `None`.")
|
||||
return None
|
||||
reader = self.get_reader(data)
|
||||
num_frames = self.get_num_frames(reader)
|
||||
duration = num_frames / self.frame_rate
|
||||
waveform, sample_rate = torchaudio.load(data)
|
||||
target_samples = int(duration * sample_rate)
|
||||
current_samples = waveform.shape[-1]
|
||||
if current_samples > target_samples:
|
||||
waveform = waveform[..., :target_samples]
|
||||
elif current_samples < target_samples:
|
||||
padding = target_samples - current_samples
|
||||
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
||||
return waveform, sample_rate
|
||||
|
||||
@@ -152,7 +152,7 @@ class BasePipeline(torch.nn.Module):
|
||||
# remove batch dim
|
||||
if audio_output.ndim == 3:
|
||||
audio_output = audio_output.squeeze(0)
|
||||
return audio_output.float().cpu()
|
||||
return audio_output.float()
|
||||
|
||||
def load_models_to_device(self, model_names):
|
||||
if self.vram_management_enabled:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing_extensions import Literal
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image", "ACE-Step"] = "FLUX.1"):
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
@@ -13,8 +13,6 @@ class FlowMatchScheduler():
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
|
||||
"ACE-Step": FlowMatchScheduler.set_timesteps_ace_step,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@@ -131,38 +129,6 @@ class FlowMatchScheduler():
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, shift=3.0):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
if shift is not None and shift != 1.0:
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0):
|
||||
"""ACE-Step Flow Matching scheduler.
|
||||
|
||||
Timesteps range from 1.0 to 0.0 (not multiplied by 1000).
|
||||
Shift transformation: t = shift * t / (1 + (shift - 1) * t)
|
||||
|
||||
Args:
|
||||
num_inference_steps: Number of diffusion steps.
|
||||
denoising_strength: Denoising strength (1.0 = full denoising).
|
||||
shift: Timestep shift parameter (default 3.0 for turbo).
|
||||
"""
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
|
||||
if shift is not None and shift != 1.0:
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
|
||||
sigma_min = 0.0
|
||||
@@ -180,18 +146,6 @@ class FlowMatchScheduler():
|
||||
timesteps[timestep_id] = timestep
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_joyai_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
shift = 4.0 if shift is None else shift
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||
num_train_timesteps = 1000
|
||||
@@ -231,7 +185,7 @@ class FlowMatchScheduler():
|
||||
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
||||
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
||||
num_inference_steps=num_inference_steps,
|
||||
|
||||
@@ -33,15 +33,15 @@ def launch_training_task(
|
||||
for epoch_id in range(num_epochs):
|
||||
for data in tqdm(dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
if dataset.load_from_cache:
|
||||
loss = model({}, inputs=data)
|
||||
else:
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||
scheduler.step()
|
||||
if save_steps is None:
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
model_logger.on_training_end(accelerator, model, save_steps)
|
||||
|
||||
@@ -1,695 +0,0 @@
|
||||
# Copyright 2025 The ACESTEO Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import can_return_tuple, logging
|
||||
from transformers.models.qwen3.modeling_qwen3 import (
|
||||
Qwen3MLP,
|
||||
Qwen3RMSNorm,
|
||||
Qwen3RotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def create_4d_mask(
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
is_sliding_window: bool = False,
|
||||
is_causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
indices = torch.arange(seq_len, device=device)
|
||||
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
|
||||
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
|
||||
if is_causal:
|
||||
valid_mask = valid_mask & (diff >= 0)
|
||||
if is_sliding_window and sliding_window is not None:
|
||||
if is_causal:
|
||||
valid_mask = valid_mask & (diff <= sliding_window)
|
||||
else:
|
||||
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
|
||||
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
|
||||
if attention_mask is not None:
|
||||
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
|
||||
valid_mask = valid_mask & padding_mask_4d
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
|
||||
mask_tensor.masked_fill_(valid_mask, 0.0)
|
||||
return mask_tensor
|
||||
|
||||
|
||||
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
|
||||
hidden_cat = torch.cat([hidden1, hidden2], dim=1)
|
||||
mask_cat = torch.cat([mask1, mask2], dim=1)
|
||||
B, L, D = hidden_cat.shape
|
||||
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
|
||||
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
|
||||
lengths = mask_cat.sum(dim=1)
|
||||
new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
|
||||
return hidden_left, new_mask
|
||||
|
||||
|
||||
class Lambda(nn.Module):
|
||||
def __init__(self, func):
|
||||
super().__init__()
|
||||
self.func = func
|
||||
|
||||
def forward(self, x):
|
||||
return self.func(x)
|
||||
|
||||
|
||||
class AceStepAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
attention_dropout: float,
|
||||
layer_types: list,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
layer_idx: int = 0,
|
||||
is_cross_attention: bool = False,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.num_key_value_groups = num_attention_heads // num_key_value_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.attention_dropout = attention_dropout
|
||||
if is_cross_attention:
|
||||
is_causal = False
|
||||
self.is_causal = is_causal
|
||||
self.is_cross_attention = is_cross_attention
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
|
||||
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.attention_type = layer_types[layer_idx]
|
||||
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
|
||||
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention:
|
||||
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
if not is_updated:
|
||||
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
||||
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
else:
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
||||
|
||||
else:
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if self.num_key_value_groups > 1:
|
||||
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
|
||||
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
|
||||
|
||||
attn_output = attention_forward(
|
||||
query_states, key_states, value_states,
|
||||
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
|
||||
attn_mask=attention_mask,
|
||||
)
|
||||
attn_weights = None
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AceStepEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
attention_dropout: float,
|
||||
layer_types: list,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=False,
|
||||
is_causal=False,
|
||||
)
|
||||
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
mlp_config = type('Config', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'intermediate_size': intermediate_size,
|
||||
'hidden_act': 'silu',
|
||||
})()
|
||||
self.mlp = Qwen3MLP(mlp_config)
|
||||
self.attention_type = layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=False,
|
||||
past_key_value=None,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
return outputs
|
||||
|
||||
|
||||
class AceStepLyricEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
use_cache: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
text_hidden_dim: int = 1024,
|
||||
num_lyric_encoder_hidden_layers: int = 8,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
|
||||
self.text_hidden_dim = text_hidden_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
|
||||
|
||||
self.embed_tokens = nn.Linear(text_hidden_dim, hidden_size)
|
||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
rope_config = type('RopeConfig', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'num_attention_heads': num_attention_heads,
|
||||
'num_key_value_heads': num_key_value_heads,
|
||||
'head_dim': head_dim,
|
||||
'max_position_embeddings': max_position_embeddings,
|
||||
'rope_theta': rope_theta,
|
||||
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
|
||||
'rms_norm_eps': rms_norm_eps,
|
||||
'attention_bias': attention_bias,
|
||||
'attention_dropout': attention_dropout,
|
||||
'hidden_act': 'silu',
|
||||
'intermediate_size': intermediate_size,
|
||||
'layer_types': self.layer_types,
|
||||
'sliding_window': sliding_window,
|
||||
'_attn_implementation': self._attn_implementation,
|
||||
})()
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=self.layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
for layer_idx in range(num_lyric_encoder_hidden_layers)
|
||||
])
|
||||
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutput:
|
||||
output_attentions = output_attentions if output_attentions is not None else False
|
||||
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
|
||||
|
||||
assert input_ids is None, "Only `inputs_embeds` is supported for the lyric encoder."
|
||||
assert attention_mask is not None, "Attention mask must be provided for the lyric encoder."
|
||||
assert inputs_embeds is not None, "Inputs embeddings must be provided for the lyric encoder."
|
||||
|
||||
inputs_embeds = self.embed_tokens(inputs_embeds)
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
dtype = inputs_embeds.dtype
|
||||
device = inputs_embeds.device
|
||||
|
||||
full_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=None,
|
||||
is_sliding_window=False, is_causal=False
|
||||
)
|
||||
sliding_attn_mask = None
|
||||
if self.use_sliding_window:
|
||||
sliding_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=self.sliding_window,
|
||||
is_sliding_window=True, is_causal=False
|
||||
)
|
||||
|
||||
self_attn_mask_mapping = {
|
||||
"full_attention": full_attn_mask,
|
||||
"sliding_attention": sliding_attn_mask,
|
||||
}
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for layer_module in self.layers[: self.num_lyric_encoder_hidden_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, position_embeddings,
|
||||
self_attn_mask_mapping[layer_module.attention_type],
|
||||
position_ids, output_attentions,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class AceStepTimbreEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
use_cache: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
timbre_hidden_dim: int = 64,
|
||||
num_timbre_encoder_hidden_layers: int = 4,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.timbre_hidden_dim = timbre_hidden_dim
|
||||
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
|
||||
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
|
||||
|
||||
self.embed_tokens = nn.Linear(timbre_hidden_dim, hidden_size)
|
||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
rope_config = type('RopeConfig', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'num_attention_heads': num_attention_heads,
|
||||
'num_key_value_heads': num_key_value_heads,
|
||||
'head_dim': head_dim,
|
||||
'max_position_embeddings': max_position_embeddings,
|
||||
'rope_theta': rope_theta,
|
||||
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
|
||||
'rms_norm_eps': rms_norm_eps,
|
||||
'attention_bias': attention_bias,
|
||||
'attention_dropout': attention_dropout,
|
||||
'hidden_act': 'silu',
|
||||
'intermediate_size': intermediate_size,
|
||||
'layer_types': self.layer_types,
|
||||
'sliding_window': sliding_window,
|
||||
'_attn_implementation': self._attn_implementation,
|
||||
})()
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
|
||||
self.gradient_checkpointing = False
|
||||
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size))
|
||||
self.layers = nn.ModuleList([
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=self.layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
for layer_idx in range(num_timbre_encoder_hidden_layers)
|
||||
])
|
||||
|
||||
def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
|
||||
N, d = timbre_embs_packed.shape
|
||||
device = timbre_embs_packed.device
|
||||
dtype = timbre_embs_packed.dtype
|
||||
B = int(refer_audio_order_mask.max().item() + 1)
|
||||
counts = torch.bincount(refer_audio_order_mask, minlength=B)
|
||||
max_count = counts.max().item()
|
||||
sorted_indices = torch.argsort(refer_audio_order_mask * N + torch.arange(N, device=device), stable=True)
|
||||
sorted_batch_ids = refer_audio_order_mask[sorted_indices]
|
||||
positions = torch.arange(N, device=device)
|
||||
batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
|
||||
positions_in_sorted = positions - batch_starts[sorted_batch_ids]
|
||||
inverse_indices = torch.empty_like(sorted_indices)
|
||||
inverse_indices[sorted_indices] = torch.arange(N, device=device)
|
||||
positions_in_batch = positions_in_sorted[inverse_indices]
|
||||
indices_2d = refer_audio_order_mask * max_count + positions_in_batch
|
||||
one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(dtype)
|
||||
timbre_embs_flat = one_hot.t() @ timbre_embs_packed
|
||||
timbre_embs_unpack = timbre_embs_flat.reshape(B, max_count, d)
|
||||
mask_flat = (one_hot.sum(dim=0) > 0).long()
|
||||
new_mask = mask_flat.reshape(B, max_count)
|
||||
return timbre_embs_unpack, new_mask
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
refer_audio_acoustic_hidden_states_packed: Optional[torch.FloatTensor] = None,
|
||||
refer_audio_order_mask: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> BaseModelOutput:
|
||||
inputs_embeds = refer_audio_acoustic_hidden_states_packed
|
||||
inputs_embeds = self.embed_tokens(inputs_embeds)
|
||||
seq_len = inputs_embeds.shape[1]
|
||||
cache_position = torch.arange(0, seq_len, device=inputs_embeds.device)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
dtype = inputs_embeds.dtype
|
||||
device = inputs_embeds.device
|
||||
|
||||
full_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=None,
|
||||
is_sliding_window=False, is_causal=False
|
||||
)
|
||||
sliding_attn_mask = None
|
||||
if self.use_sliding_window:
|
||||
sliding_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=self.sliding_window,
|
||||
is_sliding_window=True, is_causal=False
|
||||
)
|
||||
|
||||
self_attn_mask_mapping = {
|
||||
"full_attention": full_attn_mask,
|
||||
"sliding_attention": sliding_attn_mask,
|
||||
}
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for layer_module in self.layers[: self.num_timbre_encoder_hidden_layers]:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, position_embeddings,
|
||||
self_attn_mask_mapping[layer_module.attention_type],
|
||||
position_ids,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states[:, 0, :]
|
||||
# For packed input: reshape [1, T, D] -> [T, D] for unpacking
|
||||
timbre_embs_unpack, timbre_embs_mask = self.unpack_timbre_embeddings(hidden_states, refer_audio_order_mask)
|
||||
return timbre_embs_unpack, timbre_embs_mask
|
||||
|
||||
|
||||
class AceStepConditionEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
use_cache: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
text_hidden_dim: int = 1024,
|
||||
timbre_hidden_dim: int = 64,
|
||||
num_lyric_encoder_hidden_layers: int = 8,
|
||||
num_timbre_encoder_hidden_layers: int = 4,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.text_hidden_dim = text_hidden_dim
|
||||
self.timbre_hidden_dim = timbre_hidden_dim
|
||||
self.num_lyric_encoder_hidden_layers = num_lyric_encoder_hidden_layers
|
||||
self.num_timbre_encoder_hidden_layers = num_timbre_encoder_hidden_layers
|
||||
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
|
||||
|
||||
self.text_projector = nn.Linear(text_hidden_dim, hidden_size, bias=False)
|
||||
self.null_condition_emb = nn.Parameter(torch.randn(1, 1, hidden_size))
|
||||
self.lyric_encoder = AceStepLyricEncoder(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
use_sliding_window=use_sliding_window,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
text_hidden_dim=text_hidden_dim,
|
||||
num_lyric_encoder_hidden_layers=num_lyric_encoder_hidden_layers,
|
||||
)
|
||||
self.timbre_encoder = AceStepTimbreEncoder(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
use_sliding_window=use_sliding_window,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
timbre_hidden_dim=timbre_hidden_dim,
|
||||
num_timbre_encoder_hidden_layers=num_timbre_encoder_hidden_layers,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
text_attention_mask: Optional[torch.Tensor] = None,
|
||||
lyric_hidden_states: Optional[torch.LongTensor] = None,
|
||||
lyric_attention_mask: Optional[torch.Tensor] = None,
|
||||
reference_latents: Optional[torch.Tensor] = None,
|
||||
refer_audio_order_mask: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
text_hidden_states = self.text_projector(text_hidden_states)
|
||||
lyric_encoder_outputs = self.lyric_encoder(
|
||||
inputs_embeds=lyric_hidden_states,
|
||||
attention_mask=lyric_attention_mask,
|
||||
)
|
||||
lyric_hidden_states = lyric_encoder_outputs.last_hidden_state
|
||||
timbre_embs_unpack, timbre_embs_mask = self.timbre_encoder(reference_latents, refer_audio_order_mask)
|
||||
encoder_hidden_states, encoder_attention_mask = pack_sequences(
|
||||
lyric_hidden_states, timbre_embs_unpack, lyric_attention_mask, timbre_embs_mask
|
||||
)
|
||||
encoder_hidden_states, encoder_attention_mask = pack_sequences(
|
||||
encoder_hidden_states, text_hidden_states, encoder_attention_mask, text_attention_mask
|
||||
)
|
||||
return encoder_hidden_states, encoder_attention_mask
|
||||
@@ -1,901 +0,0 @@
|
||||
# Copyright 2025 The ACESTEO Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..core.attention.attention import attention_forward
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers.models.qwen3.modeling_qwen3 import (
|
||||
Qwen3MLP,
|
||||
Qwen3RMSNorm,
|
||||
Qwen3RotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def create_4d_mask(
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
attention_mask: Optional[torch.Tensor] = None, # [Batch, Seq_Len]
|
||||
sliding_window: Optional[int] = None,
|
||||
is_sliding_window: bool = False,
|
||||
is_causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
General 4D Attention Mask generator compatible with CPU/Mac/SDPA and Eager mode.
|
||||
Supports use cases:
|
||||
1. Causal Full: is_causal=True, is_sliding_window=False (standard GPT)
|
||||
2. Causal Sliding: is_causal=True, is_sliding_window=True (Mistral/Qwen local window)
|
||||
3. Bidirectional Full: is_causal=False, is_sliding_window=False (BERT/Encoder)
|
||||
4. Bidirectional Sliding: is_causal=False, is_sliding_window=True (Longformer local)
|
||||
|
||||
Returns:
|
||||
[Batch, 1, Seq_Len, Seq_Len] additive mask (0.0 for keep, -inf for mask)
|
||||
"""
|
||||
# ------------------------------------------------------
|
||||
# 1. Construct basic geometry mask [Seq_Len, Seq_Len]
|
||||
# ------------------------------------------------------
|
||||
|
||||
# Build index matrices
|
||||
# i (Query): [0, 1, ..., L-1]
|
||||
# j (Key): [0, 1, ..., L-1]
|
||||
indices = torch.arange(seq_len, device=device)
|
||||
# diff = i - j
|
||||
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
|
||||
|
||||
# Initialize all True (all positions visible)
|
||||
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
|
||||
|
||||
# (A) Handle causality (Causal)
|
||||
if is_causal:
|
||||
# i >= j => diff >= 0
|
||||
valid_mask = valid_mask & (diff >= 0)
|
||||
|
||||
# (B) Handle sliding window
|
||||
if is_sliding_window and sliding_window is not None:
|
||||
if is_causal:
|
||||
# Causal sliding: only attend to past window steps
|
||||
# i - j <= window => diff <= window
|
||||
# (diff >= 0 already handled above)
|
||||
valid_mask = valid_mask & (diff <= sliding_window)
|
||||
else:
|
||||
# Bidirectional sliding: attend past and future window steps
|
||||
# |i - j| <= window => abs(diff) <= sliding_window
|
||||
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
|
||||
|
||||
# Expand dimensions to [1, 1, Seq_Len, Seq_Len] for broadcasting
|
||||
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# ------------------------------------------------------
|
||||
# 2. Apply padding mask (Key Masking)
|
||||
# ------------------------------------------------------
|
||||
if attention_mask is not None:
|
||||
# attention_mask shape: [Batch, Seq_Len] (1=valid, 0=padding)
|
||||
# We want to mask out invalid keys (columns)
|
||||
# Expand shape: [Batch, 1, 1, Seq_Len]
|
||||
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
|
||||
|
||||
# Broadcasting: Geometry Mask [1, 1, L, L] & Padding Mask [B, 1, 1, L]
|
||||
# Result shape: [B, 1, L, L]
|
||||
valid_mask = valid_mask & padding_mask_4d
|
||||
|
||||
# ------------------------------------------------------
|
||||
# 3. Convert to additive mask
|
||||
# ------------------------------------------------------
|
||||
# Get the minimal value for current dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
# Create result tensor filled with -inf by default
|
||||
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
|
||||
|
||||
# Set valid positions to 0.0
|
||||
mask_tensor.masked_fill_(valid_mask, 0.0)
|
||||
|
||||
return mask_tensor
|
||||
|
||||
|
||||
def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
|
||||
"""
|
||||
Pack two sequences by concatenating and sorting them based on mask values.
|
||||
|
||||
Args:
|
||||
hidden1: First hidden states tensor of shape [B, L1, D]
|
||||
hidden2: Second hidden states tensor of shape [B, L2, D]
|
||||
mask1: First mask tensor of shape [B, L1]
|
||||
mask2: Second mask tensor of shape [B, L2]
|
||||
|
||||
Returns:
|
||||
Tuple of (packed_hidden_states, new_mask) where:
|
||||
- packed_hidden_states: Packed hidden states with valid tokens (mask=1) first, shape [B, L1+L2, D]
|
||||
- new_mask: New mask tensor indicating valid positions, shape [B, L1+L2]
|
||||
"""
|
||||
# Step 1: Concatenate hidden states and masks along sequence dimension
|
||||
hidden_cat = torch.cat([hidden1, hidden2], dim=1) # [B, L, D]
|
||||
mask_cat = torch.cat([mask1, mask2], dim=1) # [B, L]
|
||||
|
||||
B, L, D = hidden_cat.shape
|
||||
|
||||
# Step 2: Sort indices so that mask values of 1 come before 0
|
||||
sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) # [B, L]
|
||||
|
||||
# Step 3: Reorder hidden states using sorted indices
|
||||
hidden_left = torch.gather(hidden_cat, 1, sort_idx.unsqueeze(-1).expand(B, L, D))
|
||||
|
||||
# Step 4: Create new mask based on valid sequence lengths
|
||||
lengths = mask_cat.sum(dim=1) # [B]
|
||||
new_mask = (torch.arange(L, dtype=torch.long, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
|
||||
|
||||
return hidden_left, new_mask
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
"""
|
||||
Timestep embedding module for diffusion models.
|
||||
|
||||
Converts timestep values into high-dimensional embeddings using sinusoidal
|
||||
positional encoding, followed by MLP layers. Used for conditioning diffusion
|
||||
models on timestep information.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
scale: float = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||
self.act1 = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.act2 = nn.SiLU()
|
||||
self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)
|
||||
self.scale = scale
|
||||
|
||||
def timestep_embedding(self, t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
Args:
|
||||
t: A 1-D tensor of N indices, one per batch element. These may be fractional.
|
||||
dim: The dimension of the output embeddings.
|
||||
max_period: Controls the minimum frequency of the embeddings.
|
||||
|
||||
Returns:
|
||||
An (N, D) tensor of positional embeddings.
|
||||
"""
|
||||
t = t * self.scale
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.in_channels)
|
||||
temb = self.linear_1(t_freq.to(t.dtype))
|
||||
temb = self.act1(temb)
|
||||
temb = self.linear_2(temb)
|
||||
timestep_proj = self.time_proj(self.act2(temb)).unflatten(1, (6, -1))
|
||||
return temb, timestep_proj
|
||||
|
||||
|
||||
class AceStepAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention module for AceStep model.
|
||||
|
||||
Implements the attention mechanism from 'Attention Is All You Need' paper,
|
||||
with support for both self-attention and cross-attention modes. Uses RMSNorm
|
||||
for query and key normalization, and supports sliding window attention for
|
||||
efficient long-sequence processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
attention_dropout: float,
|
||||
layer_types: list,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
layer_idx: int = 0,
|
||||
is_cross_attention: bool = False,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.num_key_value_groups = num_attention_heads // num_key_value_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.attention_dropout = attention_dropout
|
||||
if is_cross_attention:
|
||||
is_causal = False
|
||||
self.is_causal = is_causal
|
||||
self.is_cross_attention = is_cross_attention
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
|
||||
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.attention_type = layer_types[layer_idx]
|
||||
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
# Project and normalize query states
|
||||
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
|
||||
# Determine if this is cross-attention (requires encoder_hidden_states)
|
||||
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
|
||||
|
||||
# Cross-attention path: attend to encoder hidden states
|
||||
if is_cross_attention:
|
||||
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
# After the first generated token, we can reuse all key/value states from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
|
||||
# Conditions for calculating key and value states
|
||||
if not is_updated:
|
||||
# Compute and cache K/V for the first time
|
||||
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
||||
# Update cache: save all key/value states to cache for fast auto-regressive generation
|
||||
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
|
||||
# Set flag that this layer's cross-attention cache is updated
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
else:
|
||||
# Reuse cached key/value states for subsequent tokens
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
# No cache used, compute K/V directly
|
||||
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
||||
|
||||
# Self-attention path: attend to the same sequence
|
||||
else:
|
||||
# Project and normalize key/value states for self-attention
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
# Apply rotary position embeddings (RoPE) if provided
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Update cache for auto-regressive generation
|
||||
if past_key_value is not None:
|
||||
# Sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
# GGA expansion: if num_key_value_heads < num_attention_heads
|
||||
if self.num_key_value_groups > 1:
|
||||
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
|
||||
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
|
||||
|
||||
# Use DiffSynth unified attention
|
||||
# Tensors are already in (batch, heads, seq, dim) format -> "b n s d"
|
||||
attn_output = attention_forward(
|
||||
query_states, key_states, value_states,
|
||||
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
|
||||
attn_mask=attention_mask,
|
||||
)
|
||||
|
||||
attn_weights = None # attention_forward doesn't return weights
|
||||
|
||||
# Flatten and project output: (B, n_heads, seq, dim) -> (B, seq, n_heads*dim)
|
||||
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AceStepEncoderLayer(nn.Module):
|
||||
"""
|
||||
Encoder layer for AceStep model.
|
||||
|
||||
Consists of self-attention and MLP (feed-forward) sub-layers with residual connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
intermediate_size: int = 6144,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: list = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=False,
|
||||
is_causal=False,
|
||||
)
|
||||
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
# MLP (feed-forward) sub-layer
|
||||
self.mlp = Qwen3MLP(
|
||||
config=type('Config', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'intermediate_size': intermediate_size,
|
||||
'hidden_act': 'silu',
|
||||
})()
|
||||
)
|
||||
self.attention_type = layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[
|
||||
torch.FloatTensor,
|
||||
Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
|
||||
]:
|
||||
# Self-attention with residual connection
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
# Encoders don't use cache
|
||||
use_cache=False,
|
||||
past_key_value=None,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# MLP with residual connection
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class AceStepDiTLayer(nn.Module):
|
||||
"""
|
||||
DiT (Diffusion Transformer) layer for AceStep model.
|
||||
|
||||
Implements a transformer layer with three main components:
|
||||
1. Self-attention with adaptive layer norm (AdaLN)
|
||||
2. Cross-attention (optional) for conditioning on encoder outputs
|
||||
3. Feed-forward MLP with adaptive layer norm
|
||||
|
||||
Uses scale-shift modulation from timestep embeddings for adaptive normalization.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
intermediate_size: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
attention_dropout: float,
|
||||
layer_types: list,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
layer_idx: int = 0,
|
||||
use_cross_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.self_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
self.use_cross_attention = use_cross_attention
|
||||
if self.use_cross_attention:
|
||||
self.cross_attn_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.cross_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
|
||||
self.mlp_norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.mlp = Qwen3MLP(
|
||||
config=type('Config', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'intermediate_size': intermediate_size,
|
||||
'hidden_act': 'silu',
|
||||
})()
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, hidden_size) / hidden_size**0.5)
|
||||
self.attention_type = layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Extract scale-shift parameters for adaptive layer norm from timestep embeddings
|
||||
# 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table.to(temb.device) + temb
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# Step 1: Self-attention with adaptive layer norm (AdaLN)
|
||||
# Apply adaptive normalization: norm(x) * (1 + scale) + shift
|
||||
norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output, self_attn_weights = self.self_attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=False,
|
||||
past_key_value=None,
|
||||
**kwargs,
|
||||
)
|
||||
# Apply gated residual connection: x = x + attn_output * gate
|
||||
hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# Step 2: Cross-attention (if enabled) for conditioning on encoder outputs
|
||||
if self.use_cross_attention:
|
||||
norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
|
||||
attn_output, cross_attn_weights = self.cross_attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
# Standard residual connection for cross-attention
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# Step 3: Feed-forward (MLP) with adaptive layer norm
|
||||
# Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift
|
||||
norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
|
||||
ff_output = self.mlp(norm_hidden_states)
|
||||
# Apply gated residual connection: x = x + mlp_output * gate
|
||||
hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights, cross_attn_weights)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
class Lambda(nn.Module):
|
||||
"""
|
||||
Wrapper module for arbitrary lambda functions.
|
||||
|
||||
Allows using lambda functions in nn.Sequential by wrapping them in a Module.
|
||||
Useful for simple transformations like transpose operations.
|
||||
"""
|
||||
def __init__(self, func):
|
||||
super().__init__()
|
||||
self.func = func
|
||||
|
||||
def forward(self, x):
|
||||
return self.func(x)
|
||||
|
||||
|
||||
class AceStepDiTModel(nn.Module):
|
||||
"""
|
||||
DiT (Diffusion Transformer) model for AceStep.
|
||||
|
||||
Main diffusion model that generates audio latents conditioned on text, lyrics,
|
||||
and timbre. Uses patch-based processing with transformer layers, timestep
|
||||
conditioning, and cross-attention to encoder outputs.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_hidden_layers: int = 24,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
use_cache: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 192,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
encoder_hidden_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * (num_hidden_layers // 2))
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.use_cache = use_cache
|
||||
encoder_hidden_size = encoder_hidden_size or hidden_size
|
||||
|
||||
# Rotary position embeddings for transformer layers
|
||||
rope_config = type('RopeConfig', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'num_attention_heads': num_attention_heads,
|
||||
'num_key_value_heads': num_key_value_heads,
|
||||
'head_dim': head_dim,
|
||||
'max_position_embeddings': max_position_embeddings,
|
||||
'rope_theta': rope_theta,
|
||||
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
|
||||
'rms_norm_eps': rms_norm_eps,
|
||||
'attention_bias': attention_bias,
|
||||
'attention_dropout': attention_dropout,
|
||||
'hidden_act': 'silu',
|
||||
'intermediate_size': intermediate_size,
|
||||
'layer_types': self.layer_types,
|
||||
'sliding_window': sliding_window,
|
||||
})()
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
|
||||
|
||||
# Stack of DiT transformer layers
|
||||
self.layers = nn.ModuleList([
|
||||
AceStepDiTLayer(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
intermediate_size=intermediate_size,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=self.layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
for layer_idx in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Input projection: patch embedding using 1D convolution
|
||||
self.proj_in = nn.Sequential(
|
||||
Lambda(lambda x: x.transpose(1, 2)),
|
||||
nn.Conv1d(
|
||||
in_channels=in_channels,
|
||||
out_channels=hidden_size,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=0,
|
||||
),
|
||||
Lambda(lambda x: x.transpose(1, 2)),
|
||||
)
|
||||
|
||||
# Timestep embeddings for diffusion conditioning
|
||||
self.time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
|
||||
self.time_embed_r = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size)
|
||||
|
||||
# Project encoder hidden states to model dimension
|
||||
self.condition_embedder = nn.Linear(encoder_hidden_size, hidden_size, bias=True)
|
||||
|
||||
# Output normalization and projection
|
||||
self.norm_out = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.proj_out = nn.Sequential(
|
||||
Lambda(lambda x: x.transpose(1, 2)),
|
||||
nn.ConvTranspose1d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=audio_acoustic_hidden_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
padding=0,
|
||||
),
|
||||
Lambda(lambda x: x.transpose(1, 2)),
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, hidden_size) / hidden_size**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
timestep_r: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
context_latents: torch.Tensor,
|
||||
use_cache: Optional[bool] = False,
|
||||
past_key_values: Optional[EncoderDecoderCache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
return_hidden_states: int = None,
|
||||
custom_layers_config: Optional[dict] = None,
|
||||
enable_early_exit: bool = False,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
):
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.use_cache
|
||||
|
||||
# Disable cache during training or when gradient checkpointing is enabled
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
if self.training:
|
||||
use_cache = False
|
||||
|
||||
# Initialize cache if needed (only during inference for auto-regressive generation)
|
||||
if not self.training and use_cache and past_key_values is None:
|
||||
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
|
||||
|
||||
# Compute timestep embeddings for diffusion conditioning
|
||||
# Two embeddings: one for timestep t, one for timestep difference (t - r)
|
||||
temb_t, timestep_proj_t = self.time_embed(timestep)
|
||||
temb_r, timestep_proj_r = self.time_embed_r(timestep - timestep_r)
|
||||
# Combine embeddings
|
||||
temb = temb_t + temb_r
|
||||
timestep_proj = timestep_proj_t + timestep_proj_r
|
||||
|
||||
# Concatenate context latents (source latents + chunk masks) with hidden states
|
||||
hidden_states = torch.cat([context_latents, hidden_states], dim=-1)
|
||||
# Record original sequence length for later restoration after padding
|
||||
original_seq_len = hidden_states.shape[1]
|
||||
# Apply padding if sequence length is not divisible by patch_size
|
||||
# This ensures proper patch extraction
|
||||
pad_length = 0
|
||||
if hidden_states.shape[1] % self.patch_size != 0:
|
||||
pad_length = self.patch_size - (hidden_states.shape[1] % self.patch_size)
|
||||
hidden_states = F.pad(hidden_states, (0, 0, 0, pad_length), mode='constant', value=0)
|
||||
|
||||
# Project input to patches and project encoder states
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
|
||||
|
||||
# Cache positions
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
|
||||
)
|
||||
|
||||
# Position IDs
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
seq_len = hidden_states.shape[1]
|
||||
encoder_seq_len = encoder_hidden_states.shape[1]
|
||||
dtype = hidden_states.dtype
|
||||
device = hidden_states.device
|
||||
|
||||
# Initialize Mask variables
|
||||
full_attn_mask = None
|
||||
sliding_attn_mask = None
|
||||
encoder_attn_mask = None
|
||||
decoder_attn_mask = None
|
||||
# Target library discards the passed-in attention_mask for 4D mask
|
||||
# construction (line 1384: attention_mask = None)
|
||||
attention_mask = None
|
||||
|
||||
# 1. Full Attention (Bidirectional, Global)
|
||||
full_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window=None,
|
||||
is_sliding_window=False,
|
||||
is_causal=False
|
||||
)
|
||||
max_len = max(seq_len, encoder_seq_len)
|
||||
|
||||
encoder_attn_mask = create_4d_mask(
|
||||
seq_len=max_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window=None,
|
||||
is_sliding_window=False,
|
||||
is_causal=False
|
||||
)
|
||||
encoder_attn_mask = encoder_attn_mask[:, :, :seq_len, :encoder_seq_len]
|
||||
|
||||
# 2. Sliding Attention (Bidirectional, Local)
|
||||
if self.use_sliding_window:
|
||||
sliding_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
attention_mask=attention_mask,
|
||||
sliding_window=self.sliding_window,
|
||||
is_sliding_window=True,
|
||||
is_causal=False
|
||||
)
|
||||
|
||||
# Build mask mapping
|
||||
self_attn_mask_mapping = {
|
||||
"full_attention": full_attn_mask,
|
||||
"sliding_attention": sliding_attn_mask,
|
||||
"encoder_attention_mask": encoder_attn_mask,
|
||||
}
|
||||
|
||||
# Create position embeddings to be shared across all decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
|
||||
# Handle early exit for custom layer configurations
|
||||
max_needed_layer = float('inf')
|
||||
if custom_layers_config is not None and enable_early_exit:
|
||||
max_needed_layer = max(custom_layers_config.keys())
|
||||
output_attentions = True
|
||||
if all_cross_attentions is None:
|
||||
all_cross_attentions = ()
|
||||
|
||||
# Process through transformer layers
|
||||
for index_block, layer_module in enumerate(self.layers):
|
||||
# Early exit optimization
|
||||
if index_block > max_needed_layer:
|
||||
break
|
||||
|
||||
# Prepare layer arguments
|
||||
layer_args = (
|
||||
hidden_states,
|
||||
position_embeddings,
|
||||
timestep_proj,
|
||||
self_attn_mask_mapping[layer_module.attention_type],
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
encoder_hidden_states,
|
||||
self_attn_mask_mapping["encoder_attention_mask"],
|
||||
)
|
||||
layer_kwargs = flash_attn_kwargs
|
||||
|
||||
# Use gradient checkpointing if enabled
|
||||
layer_outputs = gradient_checkpoint_forward(
|
||||
layer_module,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
*layer_args,
|
||||
**layer_kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions and self.layers[index_block].use_cross_attention:
|
||||
# layer_outputs structure: (hidden_states, self_attn_weights, cross_attn_weights)
|
||||
if len(layer_outputs) >= 3:
|
||||
all_cross_attentions += (layer_outputs[2],)
|
||||
|
||||
if return_hidden_states:
|
||||
return hidden_states
|
||||
|
||||
# Extract scale-shift parameters for adaptive output normalization
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
# Apply adaptive layer norm: norm(x) * (1 + scale) + shift
|
||||
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
|
||||
# Project output: de-patchify back to original sequence format
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# Crop back to original sequence length to ensure exact length match (remove padding)
|
||||
hidden_states = hidden_states[:, :original_seq_len, :]
|
||||
|
||||
outputs = (hidden_states, past_key_values)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (all_cross_attentions,)
|
||||
return outputs
|
||||
@@ -1,53 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class AceStepTextEncoder(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
super().__init__()
|
||||
from transformers import Qwen3Config, Qwen3Model
|
||||
|
||||
config = Qwen3Config(
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
bos_token_id=151643,
|
||||
dtype="bfloat16",
|
||||
eos_token_id=151643,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
hidden_size=1024,
|
||||
initializer_range=0.02,
|
||||
intermediate_size=3072,
|
||||
layer_types=["full_attention"] * 28,
|
||||
max_position_embeddings=32768,
|
||||
max_window_layers=28,
|
||||
model_type="qwen3",
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=28,
|
||||
num_key_value_heads=8,
|
||||
pad_token_id=151643,
|
||||
rms_norm_eps=1e-06,
|
||||
rope_scaling=None,
|
||||
rope_theta=1000000,
|
||||
sliding_window=None,
|
||||
tie_word_embeddings=True,
|
||||
use_cache=True,
|
||||
use_sliding_window=False,
|
||||
vocab_size=151669,
|
||||
)
|
||||
|
||||
self.model = Qwen3Model(config)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
attention_mask: torch.Tensor,
|
||||
):
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.last_hidden_state
|
||||
@@ -1,732 +0,0 @@
|
||||
# Copyright 2025 The ACESTEO Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ACE-Step Audio Tokenizer — VAE latent discretization pathway.
|
||||
|
||||
Contains:
|
||||
- AceStepAudioTokenizer: continuous VAE latent → discrete FSQ tokens
|
||||
- AudioTokenDetokenizer: discrete tokens → continuous VAE-latent-shaped features
|
||||
|
||||
Only used in cover song mode (is_covers=True). Bypassed in text-to-music.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
from transformers.processing_utils import Unpack
|
||||
from transformers.utils import can_return_tuple, logging
|
||||
from transformers.models.qwen3.modeling_qwen3 import (
|
||||
Qwen3MLP,
|
||||
Qwen3RMSNorm,
|
||||
Qwen3RotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
from vector_quantize_pytorch import ResidualFSQ
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def create_4d_mask(
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
is_sliding_window: bool = False,
|
||||
is_causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
indices = torch.arange(seq_len, device=device)
|
||||
diff = indices.unsqueeze(1) - indices.unsqueeze(0)
|
||||
valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)
|
||||
if is_causal:
|
||||
valid_mask = valid_mask & (diff >= 0)
|
||||
if is_sliding_window and sliding_window is not None:
|
||||
if is_causal:
|
||||
valid_mask = valid_mask & (diff <= sliding_window)
|
||||
else:
|
||||
valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)
|
||||
valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)
|
||||
if attention_mask is not None:
|
||||
padding_mask_4d = attention_mask.view(attention_mask.shape[0], 1, 1, seq_len).to(torch.bool)
|
||||
valid_mask = valid_mask & padding_mask_4d
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
|
||||
mask_tensor.masked_fill_(valid_mask, 0.0)
|
||||
return mask_tensor
|
||||
|
||||
|
||||
class Lambda(nn.Module):
|
||||
def __init__(self, func):
|
||||
super().__init__()
|
||||
self.func = func
|
||||
|
||||
def forward(self, x):
|
||||
return self.func(x)
|
||||
|
||||
|
||||
class AceStepAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
attention_dropout: float,
|
||||
layer_types: list,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
layer_idx: int = 0,
|
||||
is_cross_attention: bool = False,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.num_key_value_groups = num_attention_heads // num_key_value_heads
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.attention_dropout = attention_dropout
|
||||
if is_cross_attention:
|
||||
is_causal = False
|
||||
self.is_causal = is_causal
|
||||
self.is_cross_attention = is_cross_attention
|
||||
|
||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * self.head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * self.head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(num_attention_heads * self.head_dim, hidden_size, bias=attention_bias)
|
||||
self.q_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = Qwen3RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.attention_type = layer_types[layer_idx]
|
||||
self.sliding_window = sliding_window if layer_types[layer_idx] == "sliding_attention" else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
|
||||
is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
|
||||
|
||||
if is_cross_attention:
|
||||
encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
|
||||
if past_key_value is not None:
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
if not is_updated:
|
||||
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
||||
key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
else:
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states = self.k_norm(self.k_proj(encoder_hidden_states).view(encoder_hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2)
|
||||
|
||||
else:
|
||||
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
||||
|
||||
if self.num_key_value_groups > 1:
|
||||
key_states = key_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
|
||||
value_states = value_states.unsqueeze(2).expand(-1, -1, self.num_key_value_groups, -1, -1).flatten(1, 2)
|
||||
|
||||
attn_output = attention_forward(
|
||||
query_states, key_states, value_states,
|
||||
q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d",
|
||||
attn_mask=attention_mask,
|
||||
)
|
||||
attn_weights = None
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).flatten(2, 3).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class AceStepEncoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
rms_norm_eps: float,
|
||||
attention_bias: bool,
|
||||
attention_dropout: float,
|
||||
layer_types: list,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = AceStepAttention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=False,
|
||||
is_causal=False,
|
||||
)
|
||||
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
mlp_config = type('Config', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'intermediate_size': intermediate_size,
|
||||
'hidden_act': 'silu',
|
||||
})()
|
||||
self.mlp = Qwen3MLP(mlp_config)
|
||||
self.attention_type = layer_types[layer_idx]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=False,
|
||||
past_key_value=None,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
return outputs
|
||||
|
||||
|
||||
class AttentionPooler(nn.Module):
|
||||
"""Pools every pool_window_size frames into 1 representation via transformer + CLS token."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
num_attention_pooler_hidden_layers: int = 2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Default matches target library config (24 alternating entries).
|
||||
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
|
||||
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
|
||||
|
||||
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
|
||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
# Slice layer_types to our own layer count
|
||||
pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
|
||||
rope_config = type('RopeConfig', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'num_attention_heads': num_attention_heads,
|
||||
'num_key_value_heads': num_key_value_heads,
|
||||
'head_dim': head_dim,
|
||||
'max_position_embeddings': max_position_embeddings,
|
||||
'rope_theta': rope_theta,
|
||||
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
|
||||
'rms_norm_eps': rms_norm_eps,
|
||||
'attention_bias': attention_bias,
|
||||
'attention_dropout': attention_dropout,
|
||||
'hidden_act': 'silu',
|
||||
'intermediate_size': intermediate_size,
|
||||
'layer_types': pooler_layer_types,
|
||||
'sliding_window': sliding_window,
|
||||
'_attn_implementation': self._attn_implementation,
|
||||
})()
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
|
||||
self.gradient_checkpointing = False
|
||||
self.special_token = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
|
||||
self.layers = nn.ModuleList([
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=pooler_layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
for layer_idx in range(num_attention_pooler_hidden_layers)
|
||||
])
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> torch.Tensor:
|
||||
B, T, P, D = x.shape
|
||||
x = self.embed_tokens(x)
|
||||
special_tokens = self.special_token.expand(B, T, 1, -1).to(x.device)
|
||||
x = torch.cat([special_tokens, x], dim=2)
|
||||
x = rearrange(x, "b t p c -> (b t) p c")
|
||||
|
||||
cache_position = torch.arange(0, x.shape[1], device=x.device)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
hidden_states = x
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
seq_len = x.shape[1]
|
||||
dtype = x.dtype
|
||||
device = x.device
|
||||
|
||||
full_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=None,
|
||||
is_sliding_window=False, is_causal=False
|
||||
)
|
||||
sliding_attn_mask = None
|
||||
if self.use_sliding_window:
|
||||
sliding_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=self.sliding_window,
|
||||
is_sliding_window=True, is_causal=False
|
||||
)
|
||||
|
||||
self_attn_mask_mapping = {
|
||||
"full_attention": full_attn_mask,
|
||||
"sliding_attention": sliding_attn_mask,
|
||||
}
|
||||
|
||||
for layer_module in self.layers:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, position_embeddings,
|
||||
attention_mask=self_attn_mask_mapping[layer_module.attention_type],
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
cls_output = hidden_states[:, 0, :]
|
||||
return rearrange(cls_output, "(b t) c -> b t c", b=B)
|
||||
|
||||
|
||||
class AceStepAudioTokenizer(nn.Module):
|
||||
"""Converts continuous acoustic features (VAE latents) into discrete quantized tokens.
|
||||
|
||||
Input: [B, T, 64] (VAE latent dim)
|
||||
Output: quantized [B, T/5, 2048], indices [B, T/5, 1]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
pool_window_size: int = 5,
|
||||
fsq_dim: int = 2048,
|
||||
fsq_input_levels: list = None,
|
||||
fsq_input_num_quantizers: int = 1,
|
||||
num_attention_pooler_hidden_layers: int = 2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Default matches target library config (24 alternating entries).
|
||||
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
|
||||
self.pool_window_size = pool_window_size
|
||||
self.fsq_dim = fsq_dim
|
||||
self.fsq_input_levels = fsq_input_levels or [8, 8, 8, 5, 5, 5]
|
||||
self.fsq_input_num_quantizers = fsq_input_num_quantizers
|
||||
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
|
||||
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
|
||||
|
||||
self.audio_acoustic_proj = nn.Linear(audio_acoustic_hidden_dim, hidden_size)
|
||||
# Slice layer_types for the attention pooler
|
||||
pooler_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
|
||||
self.attention_pooler = AttentionPooler(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=pooler_layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
use_sliding_window=use_sliding_window,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
|
||||
)
|
||||
self.quantizer = ResidualFSQ(
|
||||
dim=self.fsq_dim,
|
||||
levels=self.fsq_input_levels,
|
||||
num_quantizers=self.fsq_input_num_quantizers,
|
||||
force_quantization_f32=False, # avoid autocast bug in vector_quantize_pytorch
|
||||
)
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self.audio_acoustic_proj(hidden_states)
|
||||
hidden_states = self.attention_pooler(hidden_states)
|
||||
quantized, indices = self.quantizer(hidden_states)
|
||||
return quantized, indices
|
||||
|
||||
def tokenize(self, x):
|
||||
"""Convenience: takes [B, T, 64], rearranges to patches, runs forward."""
|
||||
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=self.pool_window_size)
|
||||
return self.forward(x)
|
||||
|
||||
|
||||
class AudioTokenDetokenizer(nn.Module):
|
||||
"""Converts quantized audio tokens back to continuous acoustic representations.
|
||||
|
||||
Input: [B, T/5, hidden_size] (quantized vectors)
|
||||
Output: [B, T, 64] (VAE-latent-shaped continuous features)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
pool_window_size: int = 5,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
num_attention_pooler_hidden_layers: int = 2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Default matches target library config (24 alternating entries).
|
||||
self.layer_types = layer_types or (["sliding_attention", "full_attention"] * 12)
|
||||
self.head_dim = head_dim or hidden_size // num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.initializer_range = initializer_range
|
||||
self.pool_window_size = pool_window_size
|
||||
self.audio_acoustic_hidden_dim = audio_acoustic_hidden_dim
|
||||
self.num_attention_pooler_hidden_layers = num_attention_pooler_hidden_layers
|
||||
self._attn_implementation = kwargs.get("_attn_implementation", "sdpa")
|
||||
|
||||
self.embed_tokens = nn.Linear(hidden_size, hidden_size)
|
||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
# Slice layer_types to our own layer count (use num_audio_decoder_hidden_layers)
|
||||
detok_layer_types = self.layer_types[:num_attention_pooler_hidden_layers]
|
||||
rope_config = type('RopeConfig', (), {
|
||||
'hidden_size': hidden_size,
|
||||
'num_attention_heads': num_attention_heads,
|
||||
'num_key_value_heads': num_key_value_heads,
|
||||
'head_dim': head_dim,
|
||||
'max_position_embeddings': max_position_embeddings,
|
||||
'rope_theta': rope_theta,
|
||||
'rope_parameters': {'rope_type': 'default', 'rope_theta': rope_theta},
|
||||
'rms_norm_eps': rms_norm_eps,
|
||||
'attention_bias': attention_bias,
|
||||
'attention_dropout': attention_dropout,
|
||||
'hidden_act': 'silu',
|
||||
'intermediate_size': intermediate_size,
|
||||
'layer_types': detok_layer_types,
|
||||
'sliding_window': sliding_window,
|
||||
'_attn_implementation': self._attn_implementation,
|
||||
})()
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(rope_config)
|
||||
self.gradient_checkpointing = False
|
||||
self.special_tokens = nn.Parameter(torch.randn(1, pool_window_size, hidden_size) * 0.02)
|
||||
self.layers = nn.ModuleList([
|
||||
AceStepEncoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=detok_layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
for layer_idx in range(num_attention_pooler_hidden_layers)
|
||||
])
|
||||
self.proj_out = nn.Linear(hidden_size, audio_acoustic_hidden_dim)
|
||||
|
||||
@can_return_tuple
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> torch.Tensor:
|
||||
B, T, D = x.shape
|
||||
x = self.embed_tokens(x)
|
||||
x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
|
||||
special_tokens = self.special_tokens.expand(B, T, -1, -1)
|
||||
x = x + special_tokens.to(x.device)
|
||||
x = rearrange(x, "b t p c -> (b t) p c")
|
||||
|
||||
cache_position = torch.arange(0, x.shape[1], device=x.device)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
hidden_states = x
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
seq_len = x.shape[1]
|
||||
dtype = x.dtype
|
||||
device = x.device
|
||||
|
||||
full_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=None,
|
||||
is_sliding_window=False, is_causal=False
|
||||
)
|
||||
sliding_attn_mask = None
|
||||
if self.use_sliding_window:
|
||||
sliding_attn_mask = create_4d_mask(
|
||||
seq_len=seq_len, dtype=dtype, device=device,
|
||||
attention_mask=attention_mask, sliding_window=self.sliding_window,
|
||||
is_sliding_window=True, is_causal=False
|
||||
)
|
||||
|
||||
self_attn_mask_mapping = {
|
||||
"full_attention": full_attn_mask,
|
||||
"sliding_attention": sliding_attn_mask,
|
||||
}
|
||||
|
||||
for layer_module in self.layers:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states, position_embeddings,
|
||||
attention_mask=self_attn_mask_mapping[layer_module.attention_type],
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return rearrange(hidden_states, "(b t) p c -> b (t p) c", b=B, p=self.pool_window_size)
|
||||
|
||||
|
||||
class AceStepTokenizer(nn.Module):
|
||||
"""Container for AceStepAudioTokenizer + AudioTokenDetokenizer.
|
||||
|
||||
Provides encode/decode convenience methods for VAE latent discretization.
|
||||
Used in cover song mode to convert source audio latents to discrete tokens
|
||||
and back to continuous conditioning hints.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 2048,
|
||||
intermediate_size: int = 6144,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int = 8,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
attention_dropout: float = 0.0,
|
||||
layer_types: Optional[list] = None,
|
||||
head_dim: Optional[int] = None,
|
||||
sliding_window: Optional[int] = 128,
|
||||
use_sliding_window: bool = True,
|
||||
rope_theta: float = 1000000,
|
||||
max_position_embeddings: int = 32768,
|
||||
initializer_range: float = 0.02,
|
||||
audio_acoustic_hidden_dim: int = 64,
|
||||
pool_window_size: int = 5,
|
||||
fsq_dim: int = 2048,
|
||||
fsq_input_levels: list = None,
|
||||
fsq_input_num_quantizers: int = 1,
|
||||
num_attention_pooler_hidden_layers: int = 2,
|
||||
num_audio_decoder_hidden_layers: int = 24,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
# Default layer_types matches target library config (24 alternating entries).
|
||||
# Sub-modules (pooler/detokenizer) slice first N entries for their own layer count.
|
||||
if layer_types is None:
|
||||
layer_types = ["sliding_attention", "full_attention"] * 12
|
||||
self.tokenizer = AceStepAudioTokenizer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
use_sliding_window=use_sliding_window,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
|
||||
pool_window_size=pool_window_size,
|
||||
fsq_dim=fsq_dim,
|
||||
fsq_input_levels=fsq_input_levels,
|
||||
fsq_input_num_quantizers=fsq_input_num_quantizers,
|
||||
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
|
||||
**kwargs,
|
||||
)
|
||||
self.detokenizer = AudioTokenDetokenizer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
attention_dropout=attention_dropout,
|
||||
layer_types=layer_types,
|
||||
head_dim=head_dim,
|
||||
sliding_window=sliding_window,
|
||||
use_sliding_window=use_sliding_window,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
initializer_range=initializer_range,
|
||||
pool_window_size=pool_window_size,
|
||||
audio_acoustic_hidden_dim=audio_acoustic_hidden_dim,
|
||||
num_attention_pooler_hidden_layers=num_attention_pooler_hidden_layers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def encode(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""VAE latent [B, T, 64] → discrete tokens."""
|
||||
return self.tokenizer(hidden_states)
|
||||
|
||||
def decode(self, quantized: torch.Tensor) -> torch.Tensor:
|
||||
"""Discrete tokens [B, T/5, hidden_size] → continuous [B, T, 64]."""
|
||||
return self.detokenizer(quantized)
|
||||
|
||||
def tokenize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convenience: [B, T, 64] → quantized + indices via patch rearrangement."""
|
||||
return self.tokenizer.tokenize(x)
|
||||
@@ -1,287 +0,0 @@
|
||||
# Copyright 2025 The ACESTEO Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ACE-Step Audio VAE (AutoencoderOobleck CNN architecture).
|
||||
|
||||
This is a CNN-based VAE for audio waveform encoding/decoding.
|
||||
It uses weight-normalized convolutions and Snake1d activations.
|
||||
Does NOT depend on diffusers — pure nn.Module implementation.
|
||||
"""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
"""Snake activation: x + 1/(beta+eps) * sin(alpha*x)^2."""
|
||||
|
||||
def __init__(self, hidden_dim: int, logscale: bool = True):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
|
||||
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
|
||||
self.logscale = logscale
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
shape = hidden_states.shape
|
||||
alpha = torch.exp(self.alpha) if self.logscale else self.alpha
|
||||
beta = torch.exp(self.beta) if self.logscale else self.beta
|
||||
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
|
||||
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
|
||||
return hidden_states.reshape(shape)
|
||||
|
||||
|
||||
class OobleckResidualUnit(nn.Module):
|
||||
"""Residual unit: Snake1d → Conv1d(dilated) → Snake1d → Conv1d(1×1) + skip."""
|
||||
|
||||
def __init__(self, dimension: int = 16, dilation: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.snake1 = Snake1d(dimension)
|
||||
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
|
||||
self.snake2 = Snake1d(dimension)
|
||||
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
output = self.conv1(self.snake1(hidden_state))
|
||||
output = self.conv2(self.snake2(output))
|
||||
padding = (hidden_state.shape[-1] - output.shape[-1]) // 2
|
||||
if padding > 0:
|
||||
hidden_state = hidden_state[..., padding:-padding]
|
||||
return hidden_state + output
|
||||
|
||||
|
||||
class OobleckEncoderBlock(nn.Module):
|
||||
"""Encoder block: 3 residual units + downsampling conv."""
|
||||
|
||||
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
|
||||
super().__init__()
|
||||
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
|
||||
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
|
||||
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
|
||||
self.snake1 = Snake1d(input_dim)
|
||||
self.conv1 = weight_norm(
|
||||
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
|
||||
)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.res_unit1(hidden_state)
|
||||
hidden_state = self.res_unit2(hidden_state)
|
||||
hidden_state = self.snake1(self.res_unit3(hidden_state))
|
||||
return self.conv1(hidden_state)
|
||||
|
||||
|
||||
class OobleckDecoderBlock(nn.Module):
|
||||
"""Decoder block: upsampling conv + 3 residual units."""
|
||||
|
||||
def __init__(self, input_dim: int, output_dim: int, stride: int = 1):
|
||||
super().__init__()
|
||||
self.snake1 = Snake1d(input_dim)
|
||||
self.conv_t1 = weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2),
|
||||
)
|
||||
)
|
||||
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
|
||||
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
|
||||
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
hidden_state = self.conv_t1(hidden_state)
|
||||
hidden_state = self.res_unit1(hidden_state)
|
||||
hidden_state = self.res_unit2(hidden_state)
|
||||
return self.res_unit3(hidden_state)
|
||||
|
||||
|
||||
class OobleckEncoder(nn.Module):
|
||||
"""Full encoder: audio → latent representation [B, encoder_hidden_size, T'].
|
||||
|
||||
conv1 → [blocks] → snake1 → conv2
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_hidden_size: int = 128,
|
||||
audio_channels: int = 2,
|
||||
downsampling_ratios: list = None,
|
||||
channel_multiples: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
|
||||
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
|
||||
channel_multiples = [1] + channel_multiples
|
||||
|
||||
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
|
||||
|
||||
self.block = nn.ModuleList()
|
||||
for stride_index, stride in enumerate(downsampling_ratios):
|
||||
self.block.append(
|
||||
OobleckEncoderBlock(
|
||||
input_dim=encoder_hidden_size * channel_multiples[stride_index],
|
||||
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
|
||||
stride=stride,
|
||||
)
|
||||
)
|
||||
|
||||
d_model = encoder_hidden_size * channel_multiples[-1]
|
||||
self.snake1 = Snake1d(d_model)
|
||||
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
for block in self.block:
|
||||
hidden_state = block(hidden_state)
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
return self.conv2(hidden_state)
|
||||
|
||||
|
||||
class OobleckDecoder(nn.Module):
|
||||
"""Full decoder: latent → audio waveform [B, audio_channels, T].
|
||||
|
||||
conv1 → [blocks] → snake1 → conv2(no bias)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 128,
|
||||
input_channels: int = 64,
|
||||
audio_channels: int = 2,
|
||||
upsampling_ratios: list = None,
|
||||
channel_multiples: list = None,
|
||||
):
|
||||
super().__init__()
|
||||
upsampling_ratios = upsampling_ratios or [10, 6, 4, 4, 2]
|
||||
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
|
||||
channel_multiples = [1] + channel_multiples
|
||||
|
||||
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
|
||||
|
||||
self.block = nn.ModuleList()
|
||||
for stride_index, stride in enumerate(upsampling_ratios):
|
||||
self.block.append(
|
||||
OobleckDecoderBlock(
|
||||
input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index],
|
||||
output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1],
|
||||
stride=stride,
|
||||
)
|
||||
)
|
||||
|
||||
self.snake1 = Snake1d(channels)
|
||||
# conv2 has no bias (matches checkpoint: only weight_g/weight_v, no bias key)
|
||||
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
hidden_state = self.conv1(hidden_state)
|
||||
for block in self.block:
|
||||
hidden_state = block(hidden_state)
|
||||
hidden_state = self.snake1(hidden_state)
|
||||
return self.conv2(hidden_state)
|
||||
|
||||
|
||||
class OobleckDiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.scale = parameters.chunk(2, dim=1)
|
||||
self.std = nn.functional.softplus(self.scale) + 1e-4
|
||||
self.var = self.std * self.std
|
||||
self.logvar = torch.log(self.var)
|
||||
self.deterministic = deterministic
|
||||
|
||||
def sample(self, generator: torch.Generator | None = None) -> torch.Tensor:
|
||||
# make sure sample is on the same device as the parameters and has same dtype
|
||||
sample = torch.randn(
|
||||
self.mean.shape,
|
||||
generator=generator,
|
||||
device=self.parameters.device,
|
||||
dtype=self.parameters.dtype,
|
||||
)
|
||||
x = self.mean + self.std * sample
|
||||
return x
|
||||
|
||||
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.0])
|
||||
else:
|
||||
if other is None:
|
||||
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
|
||||
else:
|
||||
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
|
||||
var_ratio = self.var / other.var
|
||||
logvar_diff = self.logvar - other.logvar
|
||||
|
||||
kl = normalized_diff + var_ratio + logvar_diff - 1
|
||||
|
||||
kl = kl.sum(1).mean()
|
||||
return kl
|
||||
|
||||
|
||||
class AceStepVAE(nn.Module):
|
||||
"""Audio VAE for ACE-Step (AutoencoderOobleck architecture).
|
||||
|
||||
Encodes audio waveform → latent, decodes latent → audio waveform.
|
||||
Uses Snake1d activations and weight-normalized convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_hidden_size: int = 128,
|
||||
downsampling_ratios: list = None,
|
||||
channel_multiples: list = None,
|
||||
decoder_channels: int = 128,
|
||||
decoder_input_channels: int = 64,
|
||||
audio_channels: int = 2,
|
||||
sampling_rate: int = 48000,
|
||||
):
|
||||
super().__init__()
|
||||
downsampling_ratios = downsampling_ratios or [2, 4, 4, 6, 10]
|
||||
channel_multiples = channel_multiples or [1, 2, 4, 8, 16]
|
||||
upsampling_ratios = downsampling_ratios[::-1]
|
||||
|
||||
self.encoder = OobleckEncoder(
|
||||
encoder_hidden_size=encoder_hidden_size,
|
||||
audio_channels=audio_channels,
|
||||
downsampling_ratios=downsampling_ratios,
|
||||
channel_multiples=channel_multiples,
|
||||
)
|
||||
self.decoder = OobleckDecoder(
|
||||
channels=decoder_channels,
|
||||
input_channels=decoder_input_channels,
|
||||
audio_channels=audio_channels,
|
||||
upsampling_ratios=upsampling_ratios,
|
||||
channel_multiples=channel_multiples,
|
||||
)
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Audio waveform [B, audio_channels, T] → latent [B, decoder_input_channels, T']."""
|
||||
h = self.encoder(x)
|
||||
output = OobleckDiagonalGaussianDistribution(h).sample()
|
||||
return output
|
||||
|
||||
def decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
"""Latent [B, decoder_input_channels, T] → audio waveform [B, audio_channels, T']."""
|
||||
return self.decoder(z)
|
||||
|
||||
def forward(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
"""Full round-trip: encode → decode."""
|
||||
z = self.encode(sample)
|
||||
return self.decode(z)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization from all conv layers (for export/inference)."""
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
|
||||
remove_weight_norm(module)
|
||||
@@ -1,362 +0,0 @@
|
||||
"""
|
||||
Ernie-Image DiT for DiffSynth-Studio.
|
||||
|
||||
Refactored from diffusers ErnieImageTransformer2DModel to use DiffSynth core modules.
|
||||
Default parameters from actual checkpoint config.json (PaddlePaddle/ERNIE-Image transformer).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
from .flux2_dit import Timesteps, TimestepEmbedding
|
||||
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta ** scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
return out.float()
|
||||
|
||||
|
||||
class ErnieImageEmbedND3(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = list(axes_dim)
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
|
||||
emb = emb.unsqueeze(2)
|
||||
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1)
|
||||
|
||||
|
||||
class ErnieImagePatchEmbedDynamic(nn.Module):
|
||||
def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
batch_size, dim, height, width = x.shape
|
||||
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
|
||||
|
||||
|
||||
class ErnieImageSingleStreamAttnProcessor:
|
||||
def __call__(
|
||||
self,
|
||||
attn: "ErnieImageAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
rot_dim = freqs_cis.shape[-1]
|
||||
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
|
||||
cos_ = torch.cos(freqs_cis).to(x.dtype)
|
||||
sin_ = torch.sin(freqs_cis).to(x.dtype)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x_rotated = torch.cat((-x2, x1), dim=-1)
|
||||
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
hidden_states = attention_forward(
|
||||
query, key, value,
|
||||
q_pattern="b s n d",
|
||||
k_pattern="b s n d",
|
||||
v_pattern="b s n d",
|
||||
out_pattern="b s n d",
|
||||
attn_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
output = attn.to_out[0](hidden_states)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ErnieImageAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
qk_norm: str = "rms_norm",
|
||||
out_bias: bool = True,
|
||||
eps: float = 1e-5,
|
||||
out_dim: int = None,
|
||||
elementwise_affine: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
|
||||
self.use_bias = bias
|
||||
self.dropout = dropout
|
||||
|
||||
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if qk_norm == "layer_norm":
|
||||
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
|
||||
)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
|
||||
self.processor = ErnieImageSingleStreamAttnProcessor()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(self, hidden_states, attention_mask, image_rotary_emb)
|
||||
|
||||
|
||||
class ErnieImageFeedForward(nn.Module):
|
||||
def __init__(self, hidden_size: int, ffn_hidden_size: int):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
|
||||
self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
|
||||
|
||||
|
||||
class ErnieImageRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class ErnieImageSharedAdaLNBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
ffn_hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
qk_layernorm: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.adaLN_sa_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
|
||||
self.self_attention = ErnieImageAttention(
|
||||
query_dim=hidden_size,
|
||||
dim_head=hidden_size // num_heads,
|
||||
heads=num_heads,
|
||||
qk_norm="rms_norm" if qk_layernorm else None,
|
||||
eps=eps,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
)
|
||||
self.adaLN_mlp_ln = ErnieImageRMSNorm(hidden_size, eps=eps)
|
||||
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
temb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
|
||||
residual = x
|
||||
x = self.adaLN_sa_ln(x)
|
||||
x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
|
||||
x_bsh = x.permute(1, 0, 2)
|
||||
attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
|
||||
attn_out = attn_out.permute(1, 0, 2)
|
||||
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
|
||||
residual = x
|
||||
x = self.adaLN_mlp_ln(x)
|
||||
x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
|
||||
return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
|
||||
|
||||
|
||||
class ErnieImageAdaLNContinuous(nn.Module):
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
|
||||
self.linear = nn.Linear(hidden_size, hidden_size * 2)
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
|
||||
x = self.norm(x)
|
||||
x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
|
||||
return x
|
||||
|
||||
|
||||
class ErnieImageDiT(nn.Module):
|
||||
"""
|
||||
Ernie-Image DiT model for DiffSynth-Studio.
|
||||
|
||||
Architecture: SharedAdaLN + RoPE 3D + Joint Image-Text Attention.
|
||||
Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 4096,
|
||||
num_attention_heads: int = 32,
|
||||
num_layers: int = 36,
|
||||
ffn_hidden_size: int = 12288,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 128,
|
||||
patch_size: int = 1,
|
||||
text_in_dim: int = 3072,
|
||||
rope_theta: int = 256,
|
||||
rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
|
||||
eps: float = 1e-6,
|
||||
qk_layernorm: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.num_layers = num_layers
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.text_in_dim = text_in_dim
|
||||
|
||||
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
|
||||
self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
|
||||
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
|
||||
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
|
||||
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
|
||||
nn.init.zeros_(self.adaLN_modulation[-1].weight)
|
||||
nn.init.zeros_(self.adaLN_modulation[-1].bias)
|
||||
self.layers = nn.ModuleList([
|
||||
ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
|
||||
self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
|
||||
nn.init.zeros_(self.final_linear.weight)
|
||||
nn.init.zeros_(self.final_linear.bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
text_bth: torch.Tensor,
|
||||
text_lens: torch.Tensor,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
) -> torch.Tensor:
|
||||
device, dtype = hidden_states.device, hidden_states.dtype
|
||||
B, C, H, W = hidden_states.shape
|
||||
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
|
||||
N_img = Hp * Wp
|
||||
|
||||
img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
|
||||
|
||||
if self.text_proj is not None and text_bth.numel() > 0:
|
||||
text_bth = self.text_proj(text_bth)
|
||||
Tmax = text_bth.shape[1]
|
||||
text_sbh = text_bth.transpose(0, 1).contiguous()
|
||||
|
||||
x = torch.cat([img_sbh, text_sbh], dim=0)
|
||||
S = x.shape[0]
|
||||
|
||||
text_ids = torch.cat([
|
||||
torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
|
||||
torch.zeros((B, Tmax, 2), device=device)
|
||||
], dim=-1) if Tmax > 0 else torch.zeros((B, 0, 3), device=device)
|
||||
grid_yx = torch.stack(
|
||||
torch.meshgrid(torch.arange(Hp, device=device, dtype=torch.float32),
|
||||
torch.arange(Wp, device=device, dtype=torch.float32), indexing="ij"),
|
||||
dim=-1
|
||||
).reshape(-1, 2)
|
||||
image_ids = torch.cat([
|
||||
text_lens.float().view(B, 1, 1).expand(-1, N_img, -1),
|
||||
grid_yx.view(1, N_img, 2).expand(B, -1, -1)
|
||||
], dim=-1)
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
|
||||
|
||||
valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) if Tmax > 0 else torch.zeros((B, 0), device=device, dtype=torch.bool)
|
||||
attention_mask = torch.cat([
|
||||
torch.ones((B, N_img), device=device, dtype=torch.bool),
|
||||
valid_text
|
||||
], dim=1)[:, None, None, :]
|
||||
|
||||
sample = self.time_proj(timestep.to(dtype))
|
||||
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
|
||||
c = self.time_embedding(sample)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
||||
t.unsqueeze(0).expand(S, -1, -1).contiguous()
|
||||
for t in self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
]
|
||||
|
||||
for layer in self.layers:
|
||||
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
|
||||
if torch.is_grad_enabled() and use_gradient_checkpointing:
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
x,
|
||||
rotary_pos_emb,
|
||||
temb,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
x = layer(x, rotary_pos_emb, temb, attention_mask)
|
||||
|
||||
x = self.final_norm(x, c).type_as(x)
|
||||
patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
|
||||
output = patches.view(B, Hp, Wp, p, p, self.out_channels).permute(0, 5, 1, 3, 2, 4).contiguous().view(B, self.out_channels, H, W)
|
||||
|
||||
return output
|
||||
@@ -1,76 +0,0 @@
|
||||
"""
|
||||
Ernie-Image TextEncoder for DiffSynth-Studio.
|
||||
|
||||
Wraps transformers Ministral3Model to output text embeddings.
|
||||
Pattern: lazy import + manual config dict + torch.nn.Module wrapper.
|
||||
Only loads the text (language) model, ignoring vision components.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ErnieImageTextEncoder(torch.nn.Module):
|
||||
"""
|
||||
Text encoder using Ministral3Model (transformers).
|
||||
Only the text_config portion of the full Mistral3Model checkpoint.
|
||||
Uses the base model (no lm_head) since the checkpoint only has embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import Ministral3Config, Ministral3Model
|
||||
|
||||
text_config = {
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 1,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 2,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9216,
|
||||
"max_position_embeddings": 262144,
|
||||
"model_type": "ministral3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 26,
|
||||
"num_key_value_heads": 8,
|
||||
"pad_token_id": 11,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_parameters": {
|
||||
"beta_fast": 32.0,
|
||||
"beta_slow": 1.0,
|
||||
"factor": 16.0,
|
||||
"llama_4_scaling_beta": 0.1,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 16384,
|
||||
"rope_theta": 1000000.0,
|
||||
"rope_type": "yarn",
|
||||
"type": "yarn",
|
||||
},
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"use_cache": True,
|
||||
"vocab_size": 131072,
|
||||
}
|
||||
config = Ministral3Config(**text_config)
|
||||
self.model = Ministral3Model(config)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
**kwargs,
|
||||
)
|
||||
return (outputs.hidden_states,)
|
||||
@@ -1,636 +0,0 @@
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
) -> torch.Tensor:
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
emb = scale * emb
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
return get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
scale=self.scale,
|
||||
)
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
self.act = nn.SiLU()
|
||||
time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
||||
self.post_act = nn.SiLU() if post_act_fn == "silu" else None
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = hidden_size
|
||||
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
else:
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.approximate = approximate
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = F.gelu(hidden_states, approximate=self.approximate)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
# Build activation + projection matching diffusers pattern
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "gelu-approximate":
|
||||
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
|
||||
else:
|
||||
act_fn = GELU(dim, inner_dim, bias=bias)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
self.net.append(act_fn)
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
if len(args) == 0:
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = _to_tuple(args[1], dim=dim)
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
return grid
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis, x, head_first=False):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
if isinstance(freqs_cis, tuple):
|
||||
if head_first:
|
||||
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1])
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
else:
|
||||
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
else:
|
||||
if head_first:
|
||||
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
else:
|
||||
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(xq, xk, freqs_cis, head_first=False):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(dim, pos, theta=10000.0, use_real=False, theta_rescale_factor=1.0, interpolation_factor=1.0):
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
if theta_rescale_factor != 1.0:
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
freqs = torch.outer(pos * interpolation_factor, freqs)
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
return torch.polar(torch.ones_like(freqs), freqs)
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(rope_dim_list, start, *args, theta=10000.0, use_real=False,
|
||||
txt_rope_size=None, theta_rescale_factor=1.0, interpolation_factor=1.0):
|
||||
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
|
||||
if isinstance(theta_rescale_factor, (int, float)):
|
||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||
if isinstance(interpolation_factor, (int, float)):
|
||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i], grid[i].reshape(-1), theta,
|
||||
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
)
|
||||
embs.append(emb)
|
||||
if use_real:
|
||||
vis_emb = (torch.cat([emb[0] for emb in embs], dim=1), torch.cat([emb[1] for emb in embs], dim=1))
|
||||
else:
|
||||
vis_emb = torch.cat(embs, dim=1)
|
||||
if txt_rope_size is not None:
|
||||
embs_txt = []
|
||||
vis_max_ids = grid.view(-1).max().item()
|
||||
grid_txt = torch.arange(txt_rope_size) + vis_max_ids + 1
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i], grid_txt, theta,
|
||||
use_real=use_real, theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
)
|
||||
embs_txt.append(emb)
|
||||
if use_real:
|
||||
txt_emb = (torch.cat([emb[0] for emb in embs_txt], dim=1), torch.cat([emb[1] for emb in embs_txt], dim=1))
|
||||
else:
|
||||
txt_emb = torch.cat(embs_txt, dim=1)
|
||||
else:
|
||||
txt_emb = None
|
||||
return vis_emb, txt_emb
|
||||
|
||||
|
||||
class ModulateWan(nn.Module):
|
||||
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.factor = factor
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
self.modulate_table = nn.Parameter(
|
||||
torch.zeros(1, factor, hidden_size, **factory_kwargs) / hidden_size**0.5,
|
||||
requires_grad=True
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if len(x.shape) != 3:
|
||||
x = x.unsqueeze(1)
|
||||
return [o.squeeze(1) for o in (self.modulate_table + x).chunk(self.factor, dim=1)]
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None):
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
elif scale is None:
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def apply_gate(x, gate=None, tanh=False):
|
||||
if gate is None:
|
||||
return x
|
||||
if tanh:
|
||||
return x * gate.unsqueeze(1).tanh()
|
||||
else:
|
||||
return x * gate.unsqueeze(1)
|
||||
|
||||
|
||||
def load_modulation(modulate_type: str, hidden_size: int, factor: int, dtype=None, device=None):
|
||||
factory_kwargs = {"dtype": dtype, "device": device}
|
||||
if modulate_type == 'wanx':
|
||||
return ModulateWan(hidden_size, factor, **factory_kwargs)
|
||||
raise ValueError(f"Unknown modulation type: {modulate_type}. Only 'wanx' is supported.")
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(nn.Module):
|
||||
"""
|
||||
A multimodal dit block with separate modulation for
|
||||
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
|
||||
(Flux.1): https://github.com/black-forest-labs/flux
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
mlp_width_ratio: float,
|
||||
mlp_act_type: str = "gelu_tanh",
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dit_modulation_type: Optional[str] = "wanx",
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.dit_modulation_type = dit_modulation_type
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||
|
||||
self.img_mod = load_modulation(
|
||||
modulate_type=self.dit_modulation_type,
|
||||
hidden_size=hidden_size, factor=6, **factory_kwargs,
|
||||
)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
|
||||
self.img_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
self.img_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.img_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
|
||||
|
||||
self.txt_mod = load_modulation(
|
||||
modulate_type=self.dit_modulation_type,
|
||||
hidden_size=hidden_size, factor=6, **factory_kwargs,
|
||||
)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True, **factory_kwargs)
|
||||
self.txt_attn_q_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
self.txt_attn_k_norm = RMSNorm(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
||||
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=True, **factory_kwargs)
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
||||
self.txt_mlp = FeedForward(hidden_size, inner_dim=mlp_hidden_dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
vis_freqs_cis: tuple = None,
|
||||
txt_freqs_cis: tuple = None,
|
||||
attn_kwargs: Optional[dict] = {},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
(
|
||||
img_mod1_shift, img_mod1_scale, img_mod1_gate,
|
||||
img_mod2_shift, img_mod2_scale, img_mod2_gate,
|
||||
) = self.img_mod(vec)
|
||||
(
|
||||
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate,
|
||||
txt_mod2_shift, txt_mod2_scale, txt_mod2_gate,
|
||||
) = self.txt_mod(vec)
|
||||
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
|
||||
img_qkv = self.img_attn_qkv(img_modulated)
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
img_q = self.img_attn_q_norm(img_q).to(img_v)
|
||||
img_k = self.img_attn_k_norm(img_k).to(img_v)
|
||||
|
||||
if vis_freqs_cis is not None:
|
||||
img_qq, img_kk = apply_rotary_emb(img_q, img_k, vis_freqs_cis, head_first=False)
|
||||
img_q, img_k = img_qq, img_kk
|
||||
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
|
||||
txt_qkv = self.txt_attn_qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
|
||||
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
|
||||
|
||||
if txt_freqs_cis is not None:
|
||||
raise NotImplementedError("RoPE text is not supported for inference")
|
||||
|
||||
q = torch.cat((img_q, txt_q), dim=1)
|
||||
k = torch.cat((img_k, txt_k), dim=1)
|
||||
v = torch.cat((img_v, txt_v), dim=1)
|
||||
|
||||
# Use DiffSynth unified attention
|
||||
attn_out = attention_forward(
|
||||
q, k, v,
|
||||
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
||||
)
|
||||
|
||||
attn_out = attn_out.flatten(2, 3)
|
||||
img_attn, txt_attn = attn_out[:, : img.shape[1]], attn_out[:, img.shape[1]:]
|
||||
|
||||
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
|
||||
img = img + apply_gate(
|
||||
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
|
||||
gate=img_mod2_gate,
|
||||
)
|
||||
|
||||
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
|
||||
txt = txt + apply_gate(
|
||||
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
|
||||
gate=txt_mod2_gate,
|
||||
)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_freq_dim: int,
|
||||
time_proj_dim: int,
|
||||
text_embed_dim: int,
|
||||
image_embed_dim: Optional[int] = None,
|
||||
pos_embed_seq_len: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = nn.Linear(dim, time_proj_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
||||
return temb, timestep_proj, encoder_hidden_states
|
||||
|
||||
|
||||
class JoyAIImageDiT(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: list = [1, 2, 2],
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
hidden_size: int = 4096,
|
||||
heads_num: int = 32,
|
||||
text_states_dim: int = 4096,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mm_double_blocks_depth: int = 40,
|
||||
rope_dim_list: List[int] = [16, 56, 56],
|
||||
rope_type: str = 'rope',
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dit_modulation_type: str = "wanx",
|
||||
theta: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.patch_size = patch_size
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
self.rope_dim_list = rope_dim_list
|
||||
self.dit_modulation_type = dit_modulation_type
|
||||
self.mm_double_blocks_depth = mm_double_blocks_depth
|
||||
self.rope_type = rope_type
|
||||
self.theta = theta
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
if hidden_size % heads_num != 0:
|
||||
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
|
||||
|
||||
self.img_in = nn.Conv3d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=hidden_size,
|
||||
time_freq_dim=256,
|
||||
time_proj_dim=hidden_size * 6,
|
||||
text_embed_dim=text_states_dim,
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList([
|
||||
MMDoubleStreamBlock(
|
||||
self.hidden_size, self.heads_num,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
dit_modulation_type=self.dit_modulation_type,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for _ in range(mm_double_blocks_depth)
|
||||
])
|
||||
|
||||
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(hidden_size, self.out_channels * math.prod(patch_size), **factory_kwargs)
|
||||
|
||||
def get_rotary_pos_embed(self, vis_rope_size, txt_rope_size=None):
|
||||
target_ndim = 3
|
||||
if len(vis_rope_size) != target_ndim:
|
||||
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
|
||||
head_dim = self.hidden_size // self.heads_num
|
||||
rope_dim_list = self.rope_dim_list
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
assert sum(rope_dim_list) == head_dim
|
||||
vis_freqs, txt_freqs = get_nd_rotary_pos_embed(
|
||||
rope_dim_list, vis_rope_size,
|
||||
txt_rope_size=txt_rope_size if self.rope_type == 'mrope' else None,
|
||||
theta=self.theta, use_real=True, theta_rescale_factor=1,
|
||||
)
|
||||
return vis_freqs, txt_freqs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
return_dict: bool = True,
|
||||
use_gradient_checkpointing: bool = False,
|
||||
use_gradient_checkpointing_offload: bool = False,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
is_multi_item = (len(hidden_states.shape) == 6)
|
||||
num_items = 0
|
||||
if is_multi_item:
|
||||
num_items = hidden_states.shape[1]
|
||||
if num_items > 1:
|
||||
assert self.patch_size[0] == 1, "For multi-item input, patch_size[0] must be 1"
|
||||
hidden_states = torch.cat([hidden_states[:, -1:], hidden_states[:, :-1]], dim=1)
|
||||
hidden_states = rearrange(hidden_states, 'b n c t h w -> b c (n t) h w')
|
||||
|
||||
batch_size, _, ot, oh, ow = hidden_states.shape
|
||||
tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
|
||||
|
||||
if encoder_hidden_states_mask is None:
|
||||
encoder_hidden_states_mask = torch.ones(
|
||||
(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]),
|
||||
dtype=torch.bool,
|
||||
).to(encoder_hidden_states.device)
|
||||
|
||||
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
|
||||
temb, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
|
||||
if vec.shape[-1] > self.hidden_size:
|
||||
vec = vec.unflatten(1, (6, -1))
|
||||
|
||||
txt_seq_len = txt.shape[1]
|
||||
img_seq_len = img.shape[1]
|
||||
|
||||
vis_freqs_cis, txt_freqs_cis = self.get_rotary_pos_embed(
|
||||
vis_rope_size=(tt, th, tw),
|
||||
txt_rope_size=txt_seq_len if self.rope_type == 'mrope' else None,
|
||||
)
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
img=img, txt=txt, vec=vec,
|
||||
vis_freqs_cis=vis_freqs_cis, txt_freqs_cis=txt_freqs_cis,
|
||||
attn_kwargs={},
|
||||
)
|
||||
|
||||
img_len = img.shape[1]
|
||||
x = torch.cat((img, txt), 1)
|
||||
img = x[:, :img_len, ...]
|
||||
|
||||
img = self.proj_out(self.norm_out(img))
|
||||
img = self.unpatchify(img, tt, th, tw)
|
||||
|
||||
if is_multi_item:
|
||||
img = rearrange(img, 'b c (n t) h w -> b n c t h w', n=num_items)
|
||||
if num_items > 1:
|
||||
img = torch.cat([img[:, 1:], img[:, :1]], dim=1)
|
||||
|
||||
return img
|
||||
|
||||
def unpatchify(self, x, t, h, w):
|
||||
c = self.out_channels
|
||||
pt, ph, pw = self.patch_size
|
||||
assert t * h * w == x.shape[1]
|
||||
x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
|
||||
x = torch.einsum("nthwopqc->nctohpwq", x)
|
||||
return x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
||||
@@ -1,82 +0,0 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class JoyAIImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
|
||||
config = Qwen3VLConfig(
|
||||
text_config={
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 4096,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 12288,
|
||||
"max_position_embeddings": 262144,
|
||||
"model_type": "qwen3_vl_text",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-6,
|
||||
"rope_scaling": {
|
||||
"mrope_interleaved": True,
|
||||
"mrope_section": [24, 20, 20],
|
||||
"rope_type": "default",
|
||||
},
|
||||
"rope_theta": 5000000,
|
||||
"use_cache": True,
|
||||
"vocab_size": 151936,
|
||||
},
|
||||
vision_config={
|
||||
"deepstack_visual_indexes": [8, 16, 24],
|
||||
"depth": 27,
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"in_channels": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "qwen3_vl",
|
||||
"num_heads": 16,
|
||||
"num_position_embeddings": 2304,
|
||||
"out_hidden_size": 4096,
|
||||
"patch_size": 16,
|
||||
"spatial_merge_size": 2,
|
||||
"temporal_patch_size": 2,
|
||||
},
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
vision_start_token_id=151652,
|
||||
vision_end_token_id=151653,
|
||||
tie_word_embeddings=False,
|
||||
)
|
||||
|
||||
self.model = Qwen3VLForConditionalGeneration(config)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
pre_norm_output = [None]
|
||||
def hook_fn(module, args, kwargs_output=None):
|
||||
pre_norm_output[0] = args[0]
|
||||
self.model.model.language_model.norm.register_forward_hook(hook_fn)
|
||||
_ = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
**kwargs,
|
||||
)
|
||||
return pre_norm_output[0]
|
||||
@@ -1,584 +0,0 @@
|
||||
"""
|
||||
ACE-Step Pipeline for DiffSynth-Studio.
|
||||
|
||||
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
||||
"""
|
||||
import re, torch
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from tqdm import tqdm
|
||||
import random, math
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from ..models.ace_step_dit import AceStepDiTModel
|
||||
from ..models.ace_step_conditioner import AceStepConditionEncoder
|
||||
from ..models.ace_step_text_encoder import AceStepTextEncoder
|
||||
from ..models.ace_step_vae import AceStepVAE
|
||||
from ..models.ace_step_tokenizer import AceStepTokenizer
|
||||
|
||||
|
||||
class AceStepPipeline(BasePipeline):
|
||||
"""Pipeline for ACE-Step text-to-music generation."""
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device,
|
||||
torch_dtype=torch_dtype,
|
||||
height_division_factor=1,
|
||||
width_division_factor=1,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("ACE-Step")
|
||||
self.text_encoder: AceStepTextEncoder = None
|
||||
self.conditioner: AceStepConditionEncoder = None
|
||||
self.dit: AceStepDiTModel = None
|
||||
self.vae: AceStepVAE = None
|
||||
self.tokenizer_model: AceStepTokenizer = None
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AceStepUnit_TaskTypeChecker(),
|
||||
AceStepUnit_PromptEmbedder(),
|
||||
AceStepUnit_ReferenceAudioEmbedder(),
|
||||
AceStepUnit_ContextLatentBuilder(),
|
||||
AceStepUnit_ConditionEmbedder(),
|
||||
AceStepUnit_NoiseInitializer(),
|
||||
AceStepUnit_InputAudioEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_ace_step
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
self.sample_rate = 48000
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
text_tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
vram_limit: float = None,
|
||||
):
|
||||
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
|
||||
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
|
||||
pipe.dit = model_pool.fetch_model("ace_step_dit")
|
||||
pipe.vae = model_pool.fetch_model("ace_step_vae")
|
||||
pipe.vae.remove_weight_norm()
|
||||
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
|
||||
|
||||
if text_tokenizer_config is not None:
|
||||
text_tokenizer_config.download_if_necessary()
|
||||
from transformers import AutoTokenizer
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_config.path)
|
||||
if silence_latent_config is not None:
|
||||
silence_latent_config.download_if_necessary()
|
||||
pipe.silence_latent = torch.load(silence_latent_config.path, weights_only=True).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
cfg_scale: float = 1.0,
|
||||
# Lyrics
|
||||
lyrics: str = "",
|
||||
# Task type
|
||||
task_type: Optional[str] = "text2music",
|
||||
# Reference audio
|
||||
reference_audios: List[torch.Tensor] = None,
|
||||
# Source audio
|
||||
src_audio: torch.Tensor = None,
|
||||
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
||||
audio_cover_strength: float = 1.0,
|
||||
# Audio codes
|
||||
audio_code_string: Optional[str] = None,
|
||||
# Inpainting
|
||||
repainting_ranges: Optional[List[Tuple[float, float]]] = None,
|
||||
repainting_strength: float = 1.0,
|
||||
# Shape
|
||||
duration: int = 60,
|
||||
# Audio Meta
|
||||
bpm: Optional[int] = 100,
|
||||
keyscale: Optional[str] = "B minor",
|
||||
timesignature: Optional[str] = "4",
|
||||
vocal_language: Optional[str] = "unknown",
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
# Scheduler-specific parameters
|
||||
shift: float = 1.0,
|
||||
# Progress
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt, "positive": True}
|
||||
inputs_nega = {"positive": False}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"lyrics": lyrics,
|
||||
"task_type": task_type,
|
||||
"reference_audios": reference_audios,
|
||||
"src_audio": src_audio, "audio_cover_strength": audio_cover_strength, "audio_code_string": audio_code_string,
|
||||
"repainting_ranges": repainting_ranges, "repainting_strength": repainting_strength,
|
||||
"duration": duration,
|
||||
"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "vocal_language": vocal_language,
|
||||
"seed": seed,
|
||||
"rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"shift": shift,
|
||||
}
|
||||
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||
)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
self.switch_noncover_condition(inputs_shared, inputs_posi, inputs_nega, progress_id)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id,
|
||||
)
|
||||
inputs_shared["latents"] = self.step(
|
||||
self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
|
||||
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
|
||||
)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
|
||||
latents = inputs_shared["latents"].transpose(1, 2)
|
||||
vae_output = self.vae.decode(latents)
|
||||
audio_output = self.normalize_audio(vae_output, target_db=-1.0)
|
||||
audio = self.output_audio_format_check(audio_output)
|
||||
self.load_models_to_device([])
|
||||
return audio
|
||||
|
||||
def normalize_audio(self, audio: torch.Tensor, target_db: float = -1.0) -> torch.Tensor:
|
||||
peak = torch.max(torch.abs(audio))
|
||||
if peak < 1e-6:
|
||||
return audio
|
||||
target_amp = 10 ** (target_db / 20.0)
|
||||
gain = target_amp / peak
|
||||
return audio * gain
|
||||
|
||||
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
|
||||
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
|
||||
return
|
||||
if inputs_shared.get("shared_noncover", None) is None:
|
||||
return
|
||||
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
|
||||
if progress_id >= cover_steps:
|
||||
inputs_shared.update(inputs_shared.pop("shared_noncover", {}))
|
||||
inputs_posi.update(inputs_shared.pop("posi_noncover", {}))
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega.update(inputs_shared.pop("nega_noncover", {}))
|
||||
|
||||
|
||||
class AceStepUnit_TaskTypeChecker(PipelineUnit):
|
||||
"""Check and compute sequence length from duration."""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("task_type", "src_audio", "repainting_ranges", "audio_code_string"),
|
||||
output_params=("task_type",),
|
||||
)
|
||||
|
||||
def process(self, pipe, task_type, src_audio, repainting_ranges, audio_code_string):
|
||||
assert task_type in ["text2music", "cover", "repaint"], f"Unsupported task_type: {task_type}"
|
||||
if task_type == "cover":
|
||||
assert (src_audio is not None) or (audio_code_string is not None), "For cover task, either src_audio or audio_code_string must be provided."
|
||||
elif task_type == "repaint":
|
||||
assert src_audio is not None, "For repaint task, src_audio must be provided."
|
||||
assert repainting_ranges is not None and len(repainting_ranges) > 0, "For repaint task, inpainting_ranges must be provided and non-empty."
|
||||
return {}
|
||||
|
||||
|
||||
class AceStepUnit_PromptEmbedder(PipelineUnit):
|
||||
SFT_GEN_PROMPT = "# Instruction\n{}\n\n# Caption\n{}\n\n# Metas\n{}<|endoftext|>\n"
|
||||
INSTRUCTION_MAP = {
|
||||
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
||||
"cover": "Generate audio semantic tokens based on the given conditions:",
|
||||
"repaint": "Repaint the mask area based on the given conditions:",
|
||||
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
||||
"extract_default": "Extract the track from the audio:",
|
||||
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
|
||||
"lego_default": "Generate the track based on the audio context:",
|
||||
"complete": "Complete the input track with {TRACK_CLASSES}:",
|
||||
"complete_default": "Complete the input track:",
|
||||
}
|
||||
LYRIC_PROMPT = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|>"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
||||
input_params_nega={"prompt": "prompt", "positive": "positive"},
|
||||
input_params=("lyrics", "duration", "bpm", "keyscale", "timesignature", "vocal_language", "task_type"),
|
||||
output_params=("text_hidden_states", "text_attention_mask", "lyric_hidden_states", "lyric_attention_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def _encode_text(self, pipe, text, max_length=256):
|
||||
"""Encode text using Qwen3-Embedding → [B, T, 1024]."""
|
||||
text_inputs = pipe.tokenizer(
|
||||
text,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = text_inputs.input_ids.to(pipe.device)
|
||||
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
|
||||
hidden_states = pipe.text_encoder(input_ids, attention_mask)
|
||||
return hidden_states, attention_mask
|
||||
|
||||
def _encode_lyrics(self, pipe, lyric_text, max_length=2048):
|
||||
text_inputs = pipe.tokenizer(
|
||||
lyric_text,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = text_inputs.input_ids.to(pipe.device)
|
||||
attention_mask = text_inputs.attention_mask.bool().to(pipe.device)
|
||||
hidden_states = pipe.text_encoder.model.embed_tokens(input_ids)
|
||||
return hidden_states, attention_mask
|
||||
|
||||
def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str:
|
||||
bpm = meta_dict.get("bpm", "N/A")
|
||||
timesignature = meta_dict.get("timesignature", "N/A")
|
||||
keyscale = meta_dict.get("keyscale", "N/A")
|
||||
duration = meta_dict.get("duration", 30)
|
||||
duration = f"{int(duration)} seconds"
|
||||
return (
|
||||
f"- bpm: {bpm}\n"
|
||||
f"- timesignature: {timesignature}\n"
|
||||
f"- keyscale: {keyscale}\n"
|
||||
f"- duration: {duration}\n"
|
||||
)
|
||||
|
||||
def process(self, pipe, prompt, positive, lyrics, duration, bpm, keyscale, timesignature, vocal_language, task_type):
|
||||
if not positive:
|
||||
return {}
|
||||
pipe.load_models_to_device(['text_encoder'])
|
||||
meta_dict = {"bpm": bpm, "keyscale": keyscale, "timesignature": timesignature, "duration": duration}
|
||||
INSTRUCTION = self.INSTRUCTION_MAP.get(task_type, self.INSTRUCTION_MAP["text2music"])
|
||||
prompt = self.SFT_GEN_PROMPT.format(INSTRUCTION, prompt, self._dict_to_meta_string(meta_dict))
|
||||
text_hidden_states, text_attention_mask = self._encode_text(pipe, prompt, max_length=256)
|
||||
|
||||
lyric_text = self.LYRIC_PROMPT.format(vocal_language, lyrics)
|
||||
lyric_hidden_states, lyric_attention_mask = self._encode_lyrics(pipe, lyric_text, max_length=2048)
|
||||
|
||||
# TODO: remove this
|
||||
newtext = prompt + "\n\n" + lyric_text
|
||||
return {
|
||||
"text_hidden_states": text_hidden_states,
|
||||
"text_attention_mask": text_attention_mask,
|
||||
"lyric_hidden_states": lyric_hidden_states,
|
||||
"lyric_attention_mask": lyric_attention_mask,
|
||||
}
|
||||
|
||||
|
||||
class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("reference_audios",),
|
||||
output_params=("reference_latents", "refer_audio_order_mask"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe, reference_audios):
|
||||
if reference_audios is not None:
|
||||
pipe.load_models_to_device(['vae'])
|
||||
reference_audios = [
|
||||
self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
for reference_audio in reference_audios
|
||||
]
|
||||
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
|
||||
else:
|
||||
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
||||
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, reference_audios)
|
||||
return {"reference_latents": reference_latents, "refer_audio_order_mask": refer_audio_order_mask}
|
||||
|
||||
def process_reference_audio(self, audio) -> Optional[torch.Tensor]:
|
||||
if audio.ndim == 3 and audio.shape[0] == 1:
|
||||
audio = audio.squeeze(0)
|
||||
target_frames = 30 * 48000
|
||||
segment_frames = 10 * 48000
|
||||
if audio.shape[-1] < target_frames:
|
||||
repeat_times = math.ceil(target_frames / audio.shape[-1])
|
||||
audio = audio.repeat(1, repeat_times)
|
||||
total_frames = audio.shape[-1]
|
||||
segment_size = total_frames // 3
|
||||
front_start = random.randint(0, max(0, segment_size - segment_frames))
|
||||
front_audio = audio[:, front_start:front_start + segment_frames]
|
||||
middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames))
|
||||
middle_audio = audio[:, middle_start:middle_start + segment_frames]
|
||||
back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames))
|
||||
back_audio = audio[:, back_start:back_start + segment_frames]
|
||||
return torch.cat([front_audio, middle_audio, back_audio], dim=-1).unsqueeze(0)
|
||||
|
||||
def infer_refer_latent(self, pipe, refer_audioss: List[List[torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Infer packed reference-audio latents and order mask."""
|
||||
refer_audio_order_mask = []
|
||||
refer_audio_latents = []
|
||||
for batch_idx, refer_audios in enumerate(refer_audioss):
|
||||
if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0):
|
||||
refer_audio_latent = pipe.silence_latent[:, :750, :]
|
||||
refer_audio_latents.append(refer_audio_latent)
|
||||
refer_audio_order_mask.append(batch_idx)
|
||||
else:
|
||||
for refer_audio in refer_audios:
|
||||
refer_audio_latent = pipe.vae.encode(refer_audio).transpose(1, 2).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
refer_audio_latents.append(refer_audio_latent)
|
||||
refer_audio_order_mask.append(batch_idx)
|
||||
refer_audio_latents = torch.cat(refer_audio_latents, dim=0)
|
||||
refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=pipe.device, dtype=torch.long)
|
||||
return refer_audio_latents, refer_audio_order_mask
|
||||
|
||||
|
||||
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
output_params=("encoder_hidden_states", "encoder_attention_mask"),
|
||||
onload_model_names=("conditioner",),
|
||||
)
|
||||
|
||||
def process(self, pipe, inputs_shared, inputs_posi, inputs_nega):
|
||||
pipe.load_models_to_device(['conditioner'])
|
||||
encoder_hidden_states, encoder_attention_mask = pipe.conditioner(
|
||||
text_hidden_states=inputs_posi.get("text_hidden_states", None),
|
||||
text_attention_mask=inputs_posi.get("text_attention_mask", None),
|
||||
lyric_hidden_states=inputs_posi.get("lyric_hidden_states", None),
|
||||
lyric_attention_mask=inputs_posi.get("lyric_attention_mask", None),
|
||||
reference_latents=inputs_shared.get("reference_latents", None),
|
||||
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
|
||||
)
|
||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
|
||||
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
|
||||
)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
|
||||
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
|
||||
pipe, inputs_posi["prompt"], True, inputs_shared["lyrics"], inputs_shared["duration"],
|
||||
inputs_shared["bpm"], inputs_shared["keyscale"], inputs_shared["timesignature"],
|
||||
inputs_shared["vocal_language"], "text2music")
|
||||
encoder_hidden_states_noncover, encoder_attention_mask_noncover = pipe.conditioner(
|
||||
**hidden_states_noncover,
|
||||
reference_latents=inputs_shared.get("reference_latents", None),
|
||||
refer_audio_order_mask=inputs_shared.get("refer_audio_order_mask", None),
|
||||
)
|
||||
duration = inputs_shared["context_latents"].shape[1] * 1920 / pipe.vae.sampling_rate
|
||||
context_latents_noncover = AceStepUnit_ContextLatentBuilder().process(pipe, duration, None, None)["context_latents"]
|
||||
inputs_shared["shared_noncover"] = {"context_latents": context_latents_noncover}
|
||||
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_shared["nega_noncover"] = {
|
||||
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
|
||||
dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
|
||||
),
|
||||
"encoder_attention_mask": encoder_attention_mask_noncover,
|
||||
}
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("duration", "src_audio", "audio_code_string", "task_type", "repainting_ranges", "repainting_strength"),
|
||||
output_params=("context_latents", "src_latents", "chunk_masks", "attention_mask"),
|
||||
onload_model_names=("vae", "tokenizer_model",),
|
||||
)
|
||||
|
||||
def _get_silence_latent_slice(self, pipe, length: int) -> torch.Tensor:
|
||||
available = pipe.silence_latent.shape[1]
|
||||
if length <= available:
|
||||
return pipe.silence_latent[0, :length, :]
|
||||
repeats = (length + available - 1) // available
|
||||
tiled = pipe.silence_latent[0].repeat(repeats, 1)
|
||||
return tiled[:length, :]
|
||||
|
||||
def tokenize(self, tokenizer, x, silence_latent, pool_window_size):
|
||||
if x.shape[1] % pool_window_size != 0:
|
||||
pad_len = pool_window_size - (x.shape[1] % pool_window_size)
|
||||
x = torch.cat([x, silence_latent[:1,:pad_len].repeat(x.shape[0],1,1)], dim=1)
|
||||
x = rearrange(x, 'n (t_patch p) d -> n t_patch p d', p=pool_window_size)
|
||||
quantized, indices = tokenizer(x)
|
||||
return quantized
|
||||
|
||||
@staticmethod
|
||||
def _parse_audio_code_string(code_str: str) -> list:
|
||||
"""Extract integer audio codes from tokens like <|audio_code_123|>."""
|
||||
if not code_str:
|
||||
return []
|
||||
try:
|
||||
codes = []
|
||||
max_audio_code = 63999
|
||||
for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str):
|
||||
code_value = int(x)
|
||||
codes.append(max(0, min(code_value, max_audio_code)))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid audio_code_string format: {e}")
|
||||
return codes
|
||||
|
||||
def pad_src_audio(self, pipe, src_audio, task_type, repainting_ranges):
|
||||
if task_type != "repaint" or repainting_ranges is None:
|
||||
return src_audio, repainting_ranges, None, None
|
||||
min_left = min([start for start, end in repainting_ranges])
|
||||
max_right = max([end for start, end in repainting_ranges])
|
||||
total_length = src_audio.shape[-1] // pipe.vae.sampling_rate
|
||||
pad_left = max(0, -min_left)
|
||||
pad_right = max(0, max_right - total_length)
|
||||
if pad_left > 0 or pad_right > 0:
|
||||
padding_frames_left, padding_frames_right = pad_left * pipe.vae.sampling_rate, pad_right * pipe.vae.sampling_rate
|
||||
src_audio = F.pad(src_audio, (padding_frames_left, padding_frames_right), value=0.0)
|
||||
repainting_ranges = [(start + pad_left, end + pad_left) for start, end in repainting_ranges]
|
||||
return src_audio, repainting_ranges, pad_left, pad_right
|
||||
|
||||
def parse_repaint_masks(self, pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right):
|
||||
if task_type != "repaint" or repainting_ranges is None:
|
||||
return None, src_latents
|
||||
# let repainting area be repainting_strength, non-repainting area be 0.0, and blend at the boundary with cf_frames.
|
||||
max_latent_length = src_latents.shape[1]
|
||||
denoise_mask = torch.zeros((1, max_latent_length, 1), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
for start, end in repainting_ranges:
|
||||
start_frame = start * pipe.vae.sampling_rate // 1920
|
||||
end_frame = end * pipe.vae.sampling_rate // 1920
|
||||
denoise_mask[:, start_frame:end_frame, :] = repainting_strength
|
||||
# set padding areas to 1.0 (full repaint) to avoid artifacts at the boundaries caused by padding
|
||||
pad_left_frames = pad_left * pipe.vae.sampling_rate // 1920
|
||||
pad_right_frames = pad_right * pipe.vae.sampling_rate // 1920
|
||||
denoise_mask[:, :pad_left_frames, :] = 1
|
||||
denoise_mask[:, max_latent_length - pad_right_frames:, :] = 1
|
||||
|
||||
silent_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||
src_latents = src_latents * (1 - denoise_mask) + silent_latents * denoise_mask
|
||||
return denoise_mask, src_latents
|
||||
|
||||
def process(self, pipe, duration, src_audio, audio_code_string, task_type=None, repainting_ranges=None, repainting_strength=None):
|
||||
# get src_latents from audio_code_string > src_audio > silence
|
||||
source_latents = None
|
||||
denoise_mask = None
|
||||
if audio_code_string is not None:
|
||||
# use audio_cede_string to get src_latents.
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
|
||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||
codes = quantizer.get_codes_from_indices(indices)
|
||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||
quantized = quantizer.project_out(quantized)
|
||||
src_latents = pipe.tokenizer_model.detokenizer(quantized).to(pipe.device)
|
||||
max_latent_length = src_latents.shape[1]
|
||||
elif src_audio is not None:
|
||||
# use src_audio to get src_latents.
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
src_audio = src_audio.unsqueeze(0) if src_audio.dim() == 2 else src_audio
|
||||
src_audio = torch.clamp(src_audio, -1.0, 1.0)
|
||||
|
||||
src_audio, repainting_ranges, pad_left, pad_right = self.pad_src_audio(pipe, src_audio, task_type, repainting_ranges)
|
||||
|
||||
src_latents = pipe.vae.encode(src_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
||||
source_latents = src_latents # cache for potential use in audio inpainting tasks
|
||||
denoise_mask, src_latents = self.parse_repaint_masks(pipe, src_latents, task_type, repainting_ranges, repainting_strength, pad_left, pad_right)
|
||||
if task_type == "cover":
|
||||
lm_hints_5Hz = self.tokenize(pipe.tokenizer_model.tokenizer, src_latents, pipe.silence_latent, pipe.tokenizer_model.tokenizer.pool_window_size)
|
||||
src_latents = pipe.tokenizer_model.detokenizer(lm_hints_5Hz)
|
||||
max_latent_length = src_latents.shape[1]
|
||||
else:
|
||||
# use silence latents.
|
||||
max_latent_length = int(duration * pipe.sample_rate // 1920)
|
||||
src_latents = self._get_silence_latent_slice(pipe, max_latent_length).unsqueeze(0)
|
||||
chunk_masks = torch.ones((1, max_latent_length, src_latents.shape[-1]), dtype=torch.bool, device=pipe.device)
|
||||
attention_mask = torch.ones((1, max_latent_length), device=src_latents.device, dtype=pipe.torch_dtype)
|
||||
context_latents = torch.cat([src_latents, chunk_masks], dim=-1)
|
||||
return {"context_latents": context_latents, "attention_mask": attention_mask, "src_latents": source_latents, "denoise_mask": denoise_mask}
|
||||
|
||||
|
||||
class AceStepUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("context_latents", "seed", "rand_device", "src_latents"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe, context_latents, seed, rand_device, src_latents):
|
||||
src_latents_shape = (context_latents.shape[0], context_latents.shape[1], context_latents.shape[-1] // 2)
|
||||
noise = pipe.generate_noise(src_latents_shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
if src_latents is not None:
|
||||
noise = pipe.scheduler.add_noise(src_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AceStepUnit_InputAudioEmbedder(PipelineUnit):
|
||||
"""Only for training."""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("noise", "input_audio"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",),
|
||||
)
|
||||
|
||||
def process(self, pipe, noise, input_audio):
|
||||
if input_audio is None:
|
||||
return {"latents": noise}
|
||||
if pipe.scheduler.training:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_audio, sample_rate = input_audio
|
||||
input_audio = torch.clamp(input_audio, -1.0, 1.0)
|
||||
if input_audio.dim() == 2:
|
||||
input_audio = input_audio.unsqueeze(0)
|
||||
input_latents = pipe.vae.encode(input_audio.to(dtype=pipe.torch_dtype, device=pipe.device)).transpose(1, 2)
|
||||
# prevent potential size mismatch between context_latents and input_latents by cropping input_latents to the same temporal length as noise
|
||||
input_latents = input_latents[:, :noise.shape[1]]
|
||||
return {"input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_ace_step(
|
||||
dit: AceStepDiTModel,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
context_latents=None,
|
||||
attention_mask=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
decoder_outputs = dit(
|
||||
hidden_states=latents,
|
||||
timestep=timestep,
|
||||
timestep_r=timestep,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
context_latents=context_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)[0]
|
||||
return decoder_outputs
|
||||
@@ -1,266 +0,0 @@
|
||||
"""
|
||||
ERNIE-Image Text-to-Image Pipeline for DiffSynth-Studio.
|
||||
|
||||
Architecture: SharedAdaLN DiT + RoPE 3D + Joint Image-Text Attention.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from typing import Union, Optional
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from ..models.ernie_image_text_encoder import ErnieImageTextEncoder
|
||||
from ..models.ernie_image_dit import ErnieImageDiT
|
||||
from ..models.flux2_vae import Flux2VAE
|
||||
|
||||
|
||||
class ErnieImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("ERNIE-Image")
|
||||
self.text_encoder: ErnieImageTextEncoder = None
|
||||
self.dit: ErnieImageDiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
ErnieImageUnit_ShapeChecker(),
|
||||
ErnieImageUnit_PromptEmbedder(),
|
||||
ErnieImageUnit_NoiseInitializer(),
|
||||
ErnieImageUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_ernie_image
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit: float = None,
|
||||
):
|
||||
pipe = ErnieImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
pipe.text_encoder = model_pool.fetch_model("ernie_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("ernie_image_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cuda",
|
||||
# Steps
|
||||
num_inference_steps: int = 50,
|
||||
sigma_shift: float = 3.0,
|
||||
# Progress bar
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, shift=sigma_shift)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"height": height, "width": width, "seed": seed,
|
||||
"cfg_scale": cfg_scale, "num_inference_steps": num_inference_steps,
|
||||
"rand_device": rand_device,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
latents = inputs_shared["latents"]
|
||||
image = self.vae.decode(latents)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
|
||||
class ErnieImageUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width"),
|
||||
output_params=("height", "width"),
|
||||
)
|
||||
|
||||
def process(self, pipe: ErnieImagePipeline, height, width):
|
||||
height, width = pipe.check_resize_height_width(height, width)
|
||||
return {"height": height, "width": width}
|
||||
|
||||
|
||||
class ErnieImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds", "prompt_embeds_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def encode_prompt(self, pipe: ErnieImagePipeline, prompt):
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
text_hiddens = []
|
||||
text_lens_list = []
|
||||
for p in prompt:
|
||||
ids = pipe.tokenizer(
|
||||
p,
|
||||
add_special_tokens=True,
|
||||
truncation=True,
|
||||
padding=False,
|
||||
)["input_ids"]
|
||||
|
||||
if len(ids) == 0:
|
||||
if pipe.tokenizer.bos_token_id is not None:
|
||||
ids = [pipe.tokenizer.bos_token_id]
|
||||
else:
|
||||
ids = [0]
|
||||
|
||||
input_ids = torch.tensor([ids], device=pipe.device)
|
||||
outputs = pipe.text_encoder(
|
||||
input_ids=input_ids,
|
||||
)
|
||||
# Text encoder returns tuple of (hidden_states_tuple,) where each layer's hidden state is included
|
||||
all_hidden_states = outputs[0]
|
||||
hidden = all_hidden_states[-2][0] # [T, H] - second to last layer
|
||||
text_hiddens.append(hidden)
|
||||
text_lens_list.append(hidden.shape[0])
|
||||
|
||||
# Pad to uniform length
|
||||
if len(text_hiddens) == 0:
|
||||
text_in_dim = pipe.text_encoder.config.hidden_size if hasattr(pipe.text_encoder, 'config') else 3072
|
||||
return {
|
||||
"prompt_embeds": torch.zeros((0, 0, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype),
|
||||
"prompt_embeds_mask": torch.zeros((0,), device=pipe.device, dtype=torch.long),
|
||||
}
|
||||
|
||||
normalized = [th.to(pipe.device).to(pipe.torch_dtype) for th in text_hiddens]
|
||||
text_lens = torch.tensor([t.shape[0] for t in normalized], device=pipe.device, dtype=torch.long)
|
||||
Tmax = int(text_lens.max().item())
|
||||
text_in_dim = normalized[0].shape[1]
|
||||
text_bth = torch.zeros((len(normalized), Tmax, text_in_dim), device=pipe.device, dtype=pipe.torch_dtype)
|
||||
for i, t in enumerate(normalized):
|
||||
text_bth[i, :t.shape[0], :] = t
|
||||
|
||||
return {"prompt_embeds": text_bth, "prompt_embeds_mask": text_lens}
|
||||
|
||||
def process(self, pipe: ErnieImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if pipe.text_encoder is not None:
|
||||
return self.encode_prompt(pipe, prompt)
|
||||
return {}
|
||||
|
||||
|
||||
class ErnieImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ErnieImagePipeline, height, width, seed, rand_device):
|
||||
latent_h = height // pipe.height_division_factor
|
||||
latent_w = width // pipe.width_division_factor
|
||||
latent_channels = pipe.dit.in_channels
|
||||
|
||||
# Use pipeline device if rand_device is not specified
|
||||
if rand_device is None:
|
||||
rand_device = str(pipe.device)
|
||||
|
||||
noise = pipe.generate_noise(
|
||||
(1, latent_channels, latent_h, latent_w),
|
||||
seed=seed,
|
||||
rand_device=rand_device,
|
||||
rand_torch_dtype=pipe.torch_dtype,
|
||||
)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class ErnieImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ErnieImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
# T2I path: use noise directly as initial latents
|
||||
return {"latents": noise, "input_latents": None}
|
||||
|
||||
# I2I path: VAE encode input image
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image)
|
||||
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
# In inference mode, add noise to encoded latents
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
def model_fn_ernie_image(
|
||||
dit: ErnieImageDiT,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
output = dit(
|
||||
hidden_states=latents,
|
||||
timestep=timestep,
|
||||
text_bth=prompt_embeds,
|
||||
text_lens=prompt_embeds_mask,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
return output
|
||||
@@ -1,282 +0,0 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import Union, Optional
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from ..models.joyai_image_dit import JoyAIImageDiT
|
||||
from ..models.joyai_image_text_encoder import JoyAIImageTextEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
|
||||
class JoyAIImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("Wan")
|
||||
self.text_encoder: JoyAIImageTextEncoder = None
|
||||
self.dit: JoyAIImageDiT = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.processor = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
|
||||
self.units = [
|
||||
JoyAIImageUnit_ShapeChecker(),
|
||||
JoyAIImageUnit_EditImageEmbedder(),
|
||||
JoyAIImageUnit_PromptEmbedder(),
|
||||
JoyAIImageUnit_NoiseInitializer(),
|
||||
JoyAIImageUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_joyai_image
|
||||
self.compilable_models = ["dit"]
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
# Processor
|
||||
processor_config: ModelConfig = None,
|
||||
# Optional
|
||||
vram_limit: float = None,
|
||||
):
|
||||
pipe = JoyAIImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
pipe.text_encoder = model_pool.fetch_model("joyai_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("joyai_image_dit")
|
||||
pipe.vae = model_pool.fetch_model("wan_video_vae")
|
||||
|
||||
if processor_config is not None:
|
||||
processor_config.download_if_necessary()
|
||||
from transformers import AutoProcessor
|
||||
pipe.processor = AutoProcessor.from_pretrained(processor_config.path)
|
||||
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 5.0,
|
||||
# Image
|
||||
edit_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
# Steps
|
||||
max_sequence_length: int = 4096,
|
||||
num_inference_steps: int = 30,
|
||||
# Tiling
|
||||
tiled: Optional[bool] = False,
|
||||
tile_size: Optional[tuple[int, int]] = (30, 52),
|
||||
tile_stride: Optional[tuple[int, int]] = (15, 26),
|
||||
# Scheduler
|
||||
shift: Optional[float] = 4.0,
|
||||
# Progress bar
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"edit_image": edit_image,
|
||||
"denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "max_sequence_length": max_sequence_length,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
}
|
||||
|
||||
# Unit chain
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||
)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
latents = rearrange(inputs_shared["latents"], "b n c f h w -> (b n) c f h w")
|
||||
image = self.vae.decode(latents, device=self.device)[0]
|
||||
image = self.vae_output_to_image(image, pattern="C 1 H W")
|
||||
self.load_models_to_device([])
|
||||
return image
|
||||
|
||||
|
||||
class JoyAIImageUnit_ShapeChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width"),
|
||||
output_params=("height", "width"),
|
||||
)
|
||||
|
||||
def process(self, pipe: "JoyAIImagePipeline", height, width):
|
||||
height, width = pipe.check_resize_height_width(height, width)
|
||||
return {"height": height, "width": width}
|
||||
|
||||
|
||||
class JoyAIImageUnit_PromptEmbedder(PipelineUnit):
|
||||
prompt_template_encode = {
|
||||
'image':
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
|
||||
'multiple_images':
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n{}<|im_start|>assistant\n",
|
||||
'video':
|
||||
"<|im_start|>system\n \\nDescribe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
}
|
||||
prompt_template_encode_start_idx = {'image': 34, 'multiple_images': 34, 'video': 91}
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt", "positive": "positive"},
|
||||
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
|
||||
input_params=("edit_image", "max_sequence_length"),
|
||||
output_params=("prompt_embeds", "prompt_embeds_mask"),
|
||||
onload_model_names=("joyai_image_text_encoder",),
|
||||
)
|
||||
|
||||
def process(self, pipe: "JoyAIImagePipeline", prompt, positive, edit_image, max_sequence_length):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
|
||||
has_image = edit_image is not None
|
||||
|
||||
if has_image:
|
||||
prompt_embeds, prompt_embeds_mask = self._encode_with_image(pipe, prompt, edit_image, max_sequence_length)
|
||||
else:
|
||||
prompt_embeds, prompt_embeds_mask = self._encode_text_only(pipe, prompt, max_sequence_length)
|
||||
|
||||
return {"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}
|
||||
|
||||
def _encode_with_image(self, pipe, prompt, edit_image, max_sequence_length):
|
||||
template = self.prompt_template_encode['multiple_images']
|
||||
drop_idx = self.prompt_template_encode_start_idx['multiple_images']
|
||||
|
||||
image_tokens = '<image>\n'
|
||||
prompt = f"<|im_start|>user\n{image_tokens}{prompt}<|im_end|>\n"
|
||||
prompt = prompt.replace('<image>\n', '<|vision_start|><|image_pad|><|vision_end|>')
|
||||
prompt = template.format(prompt)
|
||||
inputs = pipe.processor(text=[prompt], images=[edit_image], padding=True, return_tensors="pt").to(pipe.device)
|
||||
last_hidden_states = pipe.text_encoder(**inputs)
|
||||
|
||||
prompt_embeds = last_hidden_states[:, drop_idx:]
|
||||
prompt_embeds_mask = inputs['attention_mask'][:, drop_idx:]
|
||||
|
||||
if max_sequence_length is not None and prompt_embeds.shape[1] > max_sequence_length:
|
||||
prompt_embeds = prompt_embeds[:, -max_sequence_length:, :]
|
||||
prompt_embeds_mask = prompt_embeds_mask[:, -max_sequence_length:]
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def _encode_text_only(self, pipe, prompt, max_sequence_length):
|
||||
# TODO: may support for text-only encoding in the future.
|
||||
raise NotImplementedError("Text-only encoding is not implemented yet. Please provide edit_image for now.")
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
class JoyAIImageUnit_EditImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "tiled", "tile_size", "tile_stride", "height", "width"),
|
||||
output_params=("ref_latents", "num_items", "is_multi_item"),
|
||||
onload_model_names=("wan_video_vae",),
|
||||
)
|
||||
|
||||
def process(self, pipe: "JoyAIImagePipeline", edit_image, tiled, tile_size, tile_stride, height, width):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
# Resize edit image to match target dimensions (from ShapeChecker) to ensure ref_latents matches latents
|
||||
edit_image = edit_image.resize((width, height), Image.LANCZOS)
|
||||
images = [pipe.preprocess_image(edit_image).transpose(0, 1)]
|
||||
latents = pipe.vae.encode(images, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
ref_vae = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=1).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
|
||||
return {"ref_latents": ref_vae, "edit_image": edit_image}
|
||||
|
||||
|
||||
class JoyAIImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("seed", "height", "width", "rand_device"),
|
||||
output_params=("noise"),
|
||||
)
|
||||
|
||||
def process(self, pipe: "JoyAIImagePipeline", seed, height, width, rand_device):
|
||||
latent_h = height // pipe.vae.upsampling_factor
|
||||
latent_w = width // pipe.vae.upsampling_factor
|
||||
shape = (1, 1, pipe.vae.z_dim, 1, latent_h, latent_w)
|
||||
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class JoyAIImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",),
|
||||
)
|
||||
|
||||
def process(self, pipe: JoyAIImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
|
||||
if input_image is None:
|
||||
return {"latents": noise}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if isinstance(input_image, Image.Image):
|
||||
input_image = [input_image]
|
||||
input_image = [pipe.preprocess_image(img).transpose(0, 1) for img in input_image]
|
||||
latents = pipe.vae.encode(input_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
input_latents = rearrange(latents, "(b n) c 1 h w -> b n c 1 h w", n=(len(input_image)))
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
|
||||
def model_fn_joyai_image(
|
||||
dit,
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
prompt_embeds_mask,
|
||||
ref_latents=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
img = torch.cat([ref_latents, latents], dim=1) if ref_latents is not None else latents
|
||||
|
||||
img = dit(
|
||||
hidden_states=img,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
img = img[:, -latents.size(1):]
|
||||
return img
|
||||
@@ -99,7 +99,6 @@ def save_audio(waveform: torch.Tensor, sample_rate: int, save_path: str, backend
|
||||
"""
|
||||
if waveform.dim() == 3:
|
||||
waveform = waveform[0]
|
||||
waveform.cpu()
|
||||
|
||||
if backend == "torchcodec":
|
||||
from torchcodec.encoders import AudioEncoder
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
def AceStepConditionEncoderStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
prefix = "encoder."
|
||||
|
||||
for key in state_dict:
|
||||
if key.startswith(prefix):
|
||||
new_key = key[len(prefix):]
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
|
||||
if "null_condition_emb" in state_dict:
|
||||
new_state_dict["null_condition_emb"] = state_dict["null_condition_emb"]
|
||||
|
||||
return new_state_dict
|
||||
@@ -1,10 +0,0 @@
|
||||
def AceStepDiTModelStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
prefix = "decoder."
|
||||
|
||||
for key in state_dict:
|
||||
if key.startswith(prefix):
|
||||
new_key = key[len(prefix):]
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
|
||||
return new_state_dict
|
||||
@@ -1,15 +0,0 @@
|
||||
def AceStepTextEncoderStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
prefix = "model."
|
||||
nested_prefix = "model.model."
|
||||
|
||||
for key in state_dict:
|
||||
if key.startswith(nested_prefix):
|
||||
new_key = key
|
||||
elif key.startswith(prefix):
|
||||
new_key = "model." + key
|
||||
else:
|
||||
new_key = "model." + key
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
|
||||
return new_state_dict
|
||||
@@ -1,8 +0,0 @@
|
||||
def AceStepTokenizerStateDictConverter(state_dict):
|
||||
new_state_dict = {}
|
||||
|
||||
for key in state_dict:
|
||||
if key.startswith("tokenizer.") or key.startswith("detokenizer."):
|
||||
new_state_dict[key] = state_dict[key]
|
||||
|
||||
return new_state_dict
|
||||
@@ -1,21 +0,0 @@
|
||||
def ErnieImageTextEncoderStateDictConverter(state_dict):
|
||||
"""
|
||||
Maps checkpoint keys from multimodal Mistral3Model format
|
||||
to text-only Ministral3Model format.
|
||||
|
||||
Checkpoint keys (Mistral3Model):
|
||||
language_model.model.layers.0.input_layernorm.weight
|
||||
language_model.model.norm.weight
|
||||
|
||||
Model keys (ErnieImageTextEncoder → self.model = Ministral3Model):
|
||||
model.layers.0.input_layernorm.weight
|
||||
model.norm.weight
|
||||
|
||||
Mapping: language_model. → model.
|
||||
"""
|
||||
new_state_dict = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("language_model.model."):
|
||||
new_key = key.replace("language_model.model.", "model.", 1)
|
||||
new_state_dict[new_key] = state_dict[key]
|
||||
return new_state_dict
|
||||
@@ -1,20 +0,0 @@
|
||||
def JoyAIImageTextEncoderStateDictConverter(state_dict):
|
||||
"""Convert HuggingFace Qwen3VL checkpoint keys to DiffSynth wrapper keys.
|
||||
|
||||
Mapping (checkpoint -> wrapper):
|
||||
- lm_head.weight -> model.lm_head.weight
|
||||
- model.language_model.* -> model.model.language_model.*
|
||||
- model.visual.* -> model.model.visual.*
|
||||
"""
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if key == "lm_head.weight":
|
||||
new_key = "model.lm_head.weight"
|
||||
elif key.startswith("model.language_model."):
|
||||
new_key = "model.model." + key[len("model."):]
|
||||
elif key.startswith("model.visual."):
|
||||
new_key = "model.model." + key[len("model."):]
|
||||
else:
|
||||
new_key = key
|
||||
state_dict_[new_key] = state_dict[key]
|
||||
return state_dict_
|
||||
@@ -1,164 +0,0 @@
|
||||
# ACE-Step
|
||||
|
||||
ACE-Step 1.5 is an open-source music generation model based on DiT architecture, supporting text-to-music, audio cover, repainting and other functionalities, running efficiently on consumer-grade hardware.
|
||||
|
||||
## Installation
|
||||
|
||||
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Running the following code will load the [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 3GB VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|
||||
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|
||||
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|
||||
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
The model is loaded via `AceStepPipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||
|
||||
The input parameters for `AceStepPipeline` inference include:
|
||||
|
||||
* `prompt`: Text description of the music.
|
||||
* `cfg_scale`: Classifier-free guidance scale, defaults to 1.0.
|
||||
* `lyrics`: Lyrics text.
|
||||
* `task_type`: Task type,可选 values include `"text2music"` (text-to-music), `"cover"` (audio cover), `"repaint"` (repainting), defaults to `"text2music"`.
|
||||
* `reference_audios`: List of reference audio tensors for timbre reference.
|
||||
* `src_audio`: Source audio tensor for cover or repaint tasks.
|
||||
* `denoising_strength`: Denoising strength, controlling how much the output is influenced by source audio, defaults to 1.0.
|
||||
* `audio_cover_strength`: Audio cover step ratio, controlling how many steps use cover condition in cover tasks, defaults to 1.0.
|
||||
* `audio_code_string`: Input audio code string for cover tasks with discrete audio codes.
|
||||
* `repainting_ranges`: List of repainting time ranges (tuples of floats, in seconds) for repaint tasks.
|
||||
* `repainting_strength`: Repainting intensity, controlling the degree of change in repainted areas, defaults to 1.0.
|
||||
* `duration`: Audio duration in seconds, defaults to 60.
|
||||
* `bpm`: Beats per minute, defaults to 100.
|
||||
* `keyscale`: Musical key scale, defaults to "B minor".
|
||||
* `timesignature`: Time signature, defaults to "4".
|
||||
* `vocal_language`: Vocal language, defaults to "unknown".
|
||||
* `seed`: Random seed.
|
||||
* `rand_device`: Device for noise generation, defaults to "cpu".
|
||||
* `num_inference_steps`: Number of inference steps, defaults to 8.
|
||||
* `shift`: Timestep shift parameter for the scheduler, defaults to 1.0.
|
||||
|
||||
## Model Training
|
||||
|
||||
Models in the ace_step series are trained uniformly via `examples/ace_step/model_training/train.py`. The script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths to load models from, in JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
|
||||
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||
* Basic Training Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||
* `--weight_decay`: Weight decay magnitude.
|
||||
* `--task`: Training task, defaults to `sft`.
|
||||
* Output Configuration
|
||||
* `--output_path`: Path to save the model.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||
* `--save_steps`: Interval in training steps to save the model.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Resolution Configuration
|
||||
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
|
||||
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||
* `--num_frames`: Number of frames for video (video generation models only).
|
||||
* ACE-Step Specific Parameters
|
||||
* `--tokenizer_path`: Tokenizer path, in format model_id:origin_pattern.
|
||||
* `--silence_latent_path`: Silence latent path, in format model_id:origin_pattern.
|
||||
* `--initialize_model_on_cpu`: Whether to initialize models on CPU.
|
||||
|
||||
### Example Dataset
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
|
||||
```
|
||||
|
||||
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||
@@ -1,134 +0,0 @@
|
||||
# ERNIE-Image
|
||||
|
||||
ERNIE-Image is a powerful image generation model with 8B parameters developed by Baidu, featuring a compact and efficient architecture with strong instruction-following capability. Based on an 8B DiT backbone, it delivers performance comparable to larger (20B+) models in certain scenarios while maintaining parameter efficiency. It offers reliable performance in instruction understanding and execution, text generation (English/Chinese/Japanese), and overall stability.
|
||||
|
||||
## Installation
|
||||
|
||||
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Running the following code will load the [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 3G VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
image.save("output.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
||||
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
||||
|
||||
## Model Inference
|
||||
|
||||
The model is loaded via `ErnieImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||
|
||||
The input parameters for `ErnieImagePipeline` inference include:
|
||||
|
||||
* `prompt`: The prompt describing the content to appear in the image.
|
||||
* `negative_prompt`: The negative prompt describing what should not appear in the image, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 4.0.
|
||||
* `height`: Image height, must be a multiple of 16, default value is 1024.
|
||||
* `width`: Image width, must be a multiple of 16, default value is 1024.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: The computing device for generating random Gaussian noise matrices, default is `"cuda"`. When set to `cuda`, different GPUs will produce different results.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 50.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the "Model Overview" table above.
|
||||
|
||||
## Model Training
|
||||
|
||||
ERNIE-Image series models are trained uniformly via [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py). The script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths to load models from, in JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"PaddlePaddle/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`, separated by commas.
|
||||
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||
* Basic Training Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||
* `--weight_decay`: Weight decay magnitude.
|
||||
* `--task`: Training task, defaults to `sft`.
|
||||
* Output Configuration
|
||||
* `--output_path`: Path to save the model.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||
* `--save_steps`: Interval in training steps to save the model.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Resolution Configuration
|
||||
* `--height`: Height of the image. Leave empty to enable dynamic resolution.
|
||||
* `--width`: Width of the image. Leave empty to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||
* ERNIE-Image Specific Parameters
|
||||
* `--tokenizer_path`: Path to the tokenizer, leave empty to auto-download from remote.
|
||||
|
||||
We provide an example image dataset for testing, which can be downloaded with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
|
||||
```
|
||||
|
||||
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||
@@ -1,154 +0,0 @@
|
||||
# JoyAI-Image
|
||||
|
||||
JoyAI-Image is a unified multi-modal foundation model open-sourced by JD.com, supporting image understanding, text-to-image generation, and instruction-guided image editing.
|
||||
|
||||
## Installation
|
||||
|
||||
Before performing model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information on installation, please refer to [Setup Dependencies](../Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Running the following code will load the [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) model for inference. VRAM management is enabled, the framework automatically controls parameter loading based on available VRAM, requiring a minimum of 4GB VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
# Download dataset
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||
)
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = JoyAIImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
# Use first sample from dataset
|
||||
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||
prompt = "将裙子改为粉色"
|
||||
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
edit_image=edit_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=0,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=5.0,
|
||||
)
|
||||
|
||||
output.save("output_joyai_edit_low_vram.png")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Full Training Validation|LoRA Training|LoRA Training Validation|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||
|
||||
## Model Inference
|
||||
|
||||
The model is loaded via `JoyAIImagePipeline.from_pretrained`, see [Loading Models](../Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||
|
||||
The input parameters for `JoyAIImagePipeline` inference include:
|
||||
|
||||
* `prompt`: Text prompt describing the desired image editing effect.
|
||||
* `negative_prompt`: Negative prompt specifying what should not appear in the result, defaults to empty string.
|
||||
* `cfg_scale`: Classifier-free guidance scale factor, defaults to 5.0. Higher values make the output more closely follow the prompt.
|
||||
* `edit_image`: Image to be edited.
|
||||
* `denoising_strength`: Denoising strength controlling how much the input image is repainted, defaults to 1.0.
|
||||
* `height`: Height of the output image, defaults to 1024. Must be divisible by 16.
|
||||
* `width`: Width of the output image, defaults to 1024. Must be divisible by 16.
|
||||
* `seed`: Random seed for reproducibility. Set to `None` for random seed.
|
||||
* `max_sequence_length`: Maximum sequence length for the text encoder, defaults to 4096.
|
||||
* `num_inference_steps`: Number of inference steps, defaults to 30. More steps typically yield better quality.
|
||||
* `tiled`: Whether to enable tiling for reduced VRAM usage, defaults to False.
|
||||
* `tile_size`: Tile size, defaults to (30, 52).
|
||||
* `tile_stride`: Tile stride, defaults to (15, 26).
|
||||
* `shift`: Shift parameter for the scheduler, controlling the Flow Match scheduling curve, defaults to 4.0.
|
||||
* `progress_bar_cmd`: Progress bar display mode, defaults to tqdm.
|
||||
|
||||
## Model Training
|
||||
|
||||
Models in the joyai_image series are trained uniformly via `examples/joyai_image/model_training/train.py`. The script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Path to the dataset metadata file.
|
||||
* `--dataset_repeat`: Number of dataset repeats per epoch.
|
||||
* `--dataset_num_workers`: Number of processes per DataLoader.
|
||||
* `--data_file_keys`: Field names to load from metadata, typically paths to image or video files, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths to load models from, in JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, separated by commas.
|
||||
* `--extra_inputs`: Additional input parameters required by the model Pipeline, separated by `,`.
|
||||
* `--fp8_models`: Models to load in FP8 format, currently only supported for models whose parameters are not updated by gradients.
|
||||
* Basic Training Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether unused parameters exist in DDP training.
|
||||
* `--weight_decay`: Weight decay magnitude.
|
||||
* `--task`: Training task, defaults to `sft`.
|
||||
* Output Configuration
|
||||
* `--output_path`: Path to save the model.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the model's state dict.
|
||||
* `--save_steps`: Interval in training steps to save the model.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path to LoRA checkpoint.
|
||||
* `--preset_lora_path`: Path to preset LoRA checkpoint for LoRA differential training.
|
||||
* `--preset_lora_model`: Which model to integrate preset LoRA into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Resolution Configuration
|
||||
* `--height`: Height of the image/video. Leave empty to enable dynamic resolution.
|
||||
* `--width`: Width of the image/video. Leave empty to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area, images larger than this will be scaled down during dynamic resolution.
|
||||
* `--num_frames`: Number of frames for video (video generation models only).
|
||||
* JoyAI-Image Specific Parameters
|
||||
* `--processor_path`: Path to the processor for processing text and image encoder inputs.
|
||||
* `--initialize_model_on_cpu`: Whether to initialize models on CPU. By default, models are initialized on the accelerator device.
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
|
||||
```
|
||||
|
||||
We provide recommended training scripts for each model, please refer to the table in "Model Overview" above. For guidance on writing model training scripts, see [Model Training](../Pipeline_Usage/Model_Training.md); for more advanced training algorithms, see [Training Framework Overview](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/en/Training/).
|
||||
@@ -1,94 +0,0 @@
|
||||
# Inference Acceleration
|
||||
|
||||
The denoising process of diffusion models is typically time-consuming. To improve inference speed, various acceleration techniques can be applied, including lossless acceleration solutions such as multi-GPU parallel inference and computation graph compilation, as well as lossy acceleration solutions like Cache and quantization.
|
||||
|
||||
Currently, most diffusion models are built on Diffusion Transformers (DiT), and efficient attention mechanisms are also a common acceleration method. DiffSynth-Studio currently supports certain lossless acceleration inference features. This section focuses on introducing acceleration methods from two dimensions: multi-GPU parallel inference and computation graph compilation.
|
||||
|
||||
## Efficient Attention Mechanisms
|
||||
|
||||
For details on the acceleration of attention mechanisms, please refer to [Attention Mechanism Implementation](../API_Reference/core/attention.md).
|
||||
|
||||
## Multi-GPU Parallel Inference
|
||||
|
||||
DiffSynth-Studio adopts a multi-GPU inference solution using Unified Sequence Parallel (USP). It splits the token sequence in the DiT across multiple GPUs for parallel processing. The underlying implementation is based on [xDiT](https://github.com/xdit-project/xDiT). Please note that unified sequence parallelism introduces additional communication overhead, so the actual speedup ratio is usually lower than the number of GPUs.
|
||||
|
||||
Currently, DiffSynth-Studio supports unified sequence parallel acceleration for the [Wan](../Model_Details/Wan.md) and [MOVA](../Model_Details/Wan.md) models.
|
||||
|
||||
First, install the `xDiT` dependency.
|
||||
|
||||
```bash
|
||||
pip install "xfuser[flash-attn]>=0.4.3"
|
||||
```
|
||||
|
||||
Then, use `torchrun` to launch multi-GPU inference.
|
||||
|
||||
```bash
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||
```
|
||||
|
||||
When building the pipeline, simply configure `use_usp=True` to enable USP parallel inference. A code example is shown below.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
import torch.distributed as dist
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
use_usp=True,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
if dist.get_rank() == 0:
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
## Computation Graph Compilation
|
||||
|
||||
PyTorch 2.0 provides an automatic computation graph compilation interface, [torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), which can just-in-time (JIT) compile PyTorch code into optimized kernels, thereby improving execution speed. Since the inference time of diffusion models is concentrated in the multi-step denoising phase of the DiT, and the DiT is primarily stacked with basic blocks, DiffSynth's compile feature uses a [regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) strategy targeting only the basic Transformer blocks to reduce compilation time.
|
||||
|
||||
### Compile Usage Example
|
||||
|
||||
Compared to standard inference, you simply need to execute `pipe.compile_pipeline()` before calling the pipeline to enable compilation acceleration. For the specific function definition, please refer to the [source code](https://github.com/modelscope/DiffSynth-Studio/blob/166e6d2d38764209f66a74dd0fe468226536ad0f/diffsynth/diffusion/base_pipeline.py#L342).
|
||||
|
||||
The input parameters for `compile_pipeline` consist mainly of two types.
|
||||
|
||||
The first type is the compiled model parameters, `compile_models`. Taking the Qwen-Image Pipeline as an example, if you only want to compile the DiT model, you can keep this parameter empty. If you need to additionally compile models like the VAE, you can pass `compile_models=["vae", "dit"]`. Aside from DiT, all other models use a full-graph compilation strategy, meaning the model's forward function is completely compiled into a computation graph.
|
||||
|
||||
The second type is the compilation strategy parameters. This covers `mode`, `dynamic`, `fullgraph`, and other custom options. These parameters are directly passed to the `torch.compile` interface. If you are not deeply familiar with the specific mechanics of these parameters, it is recommended to keep the default settings.
|
||||
|
||||
* `mode` specifies the compilation mode, including `"default"`, `"reduce-overhead"`, `"max-autotune"`, and `"max-autotune-no-cudagraphs"`. Because cudagraph has stricter requirements on computation graphs (for example, it might need to be used in conjunction with `torch.compiler.cudagraph_mark_step_begin()`), the `"reduce-overhead"` and `"max-autotune"` modes might fail to compile.
|
||||
* `dynamic` determines whether to enable dynamic shapes. For most generative models, modifying the prompt, enabling CFG, or adjusting the resolution will change the shape of the input tensors to the computation graph. Setting `dynamic=True` will increase the compilation time of the first run, but it supports dynamic shapes, meaning no recompilation is needed when shapes change. When set to `dynamic=False`, the first compilation is faster, but any operation that alters the input shape will trigger a recompilation. For most scenarios, setting it to `dynamic=True` is recommended.
|
||||
* `fullgraph`, when set to `True`, makes the underlying system attempt to compile the target model into a single computation graph, throwing an error if it fails. When set to `False`, the underlying system will set breakpoints where connections cannot be made, compiling the model into multiple independent computation graphs. Developers can set it to `True` to optimize compilation performance, but regular users are advised to only use `False`.
|
||||
* For other parameter configurations, please consult the [API documentation](https://docs.pytorch.org/docs/stable/generated/torch.compile.html).
|
||||
|
||||
### Compile Feature Developer Documentation
|
||||
|
||||
If you need to provide compile support for a newly integrated pipeline, you should configure the `compilable_models` attribute in the pipeline to specify the default models to compile. For the DiT model class of that pipeline, you also need to configure `_repeated_blocks` to specify the types of basic blocks that will participate in regional compilation.
|
||||
|
||||
Taking Qwen-Image as an example, its pipeline configuration is as follows:
|
||||
|
||||
```python
|
||||
self.compilable_models = ["dit"]
|
||||
```
|
||||
|
||||
Its DiT configuration is as follows:
|
||||
|
||||
```python
|
||||
class QwenImageDiT(torch.nn.Module):
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
```
|
||||
@@ -13,7 +13,6 @@ Welcome to DiffSynth-Studio's Documentation
|
||||
|
||||
Pipeline_Usage/Setup
|
||||
Pipeline_Usage/Model_Inference
|
||||
Pipeline_Usage/Accelerated_Inference
|
||||
Pipeline_Usage/VRAM_management
|
||||
Pipeline_Usage/Model_Training
|
||||
Pipeline_Usage/Environment_Variables
|
||||
@@ -30,9 +29,6 @@ Welcome to DiffSynth-Studio's Documentation
|
||||
Model_Details/Z-Image
|
||||
Model_Details/Anima
|
||||
Model_Details/LTX-2
|
||||
Model_Details/ERNIE-Image
|
||||
Model_Details/JoyAI-Image
|
||||
Model_Details/ACE-Step
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
# ACE-Step
|
||||
|
||||
ACE-Step 1.5 是一个开源音乐生成模型,基于 DiT 架构,支持文生音乐、音频翻唱、局部重绘等多种功能,可在消费级硬件上高效运行。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[ACE-Step/Ace-Step1.5](https://www.modelscope.cn/models/ACE-Step/Ace-Step1.5)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/Ace-Step1.5.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/Ace-Step1.5.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/Ace-Step1.5.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift1](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift1)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift1.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift1.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift1.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-shift3](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-shift3)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-shift3.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-shift3.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-shift3.py)|
|
||||
|[ACE-Step/acestep-v15-turbo-continuous](https://www.modelscope.cn/models/ACE-Step/acestep-v15-turbo-continuous)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-turbo-continuous.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-turbo-continuous.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-turbo-continuous.py)|
|
||||
|[ACE-Step/acestep-v15-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-base.py)|
|
||||
|[ACE-Step/acestep-v15-base: CoverTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-CoverTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-CoverTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-base: RepaintTask](https://www.modelscope.cn/models/ACE-Step/acestep-v15-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-base-RepaintTask.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-base-RepaintTask.py)|—|—|—|—|
|
||||
|[ACE-Step/acestep-v15-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-base](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-base)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-base.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-base.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-base.py)|
|
||||
|[ACE-Step/acestep-v15-xl-sft](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-sft)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-sft.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-sft.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-sft.py)|
|
||||
|[ACE-Step/acestep-v15-xl-turbo](https://www.modelscope.cn/models/ACE-Step/acestep-v15-xl-turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_inference_low_vram/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/full/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_full/acestep-v15-xl-turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/lora/acestep-v15-xl-turbo.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ace_step/model_training/validate_lora/acestep-v15-xl-turbo.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `AceStepPipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`AceStepPipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 音乐文本描述。
|
||||
* `cfg_scale`: 分类器无条件引导比例,默认为 1.0。
|
||||
* `lyrics`: 歌词文本。
|
||||
* `task_type`: 任务类型,可选值包括 `"text2music"`(文生音乐)、`"cover"`(音频翻唱)、`"repaint"`(局部重绘),默认为 `"text2music"`。
|
||||
* `reference_audios`: 参考音频列表(Tensor 列表),用于提供音色参考。
|
||||
* `src_audio`: 源音频(Tensor),用于 cover 或 repaint 任务。
|
||||
* `denoising_strength`: 降噪强度,控制输出受源音频的影响程度,默认为 1.0。
|
||||
* `audio_cover_strength`: 音频翻唱步数比例,控制 cover 任务中前多少步使用翻唱条件,默认为 1.0。
|
||||
* `audio_code_string`: 输入音频码字符串,用于 cover 任务中直接传入离散音频码。
|
||||
* `repainting_ranges`: 重绘时间区间(浮点元组列表,单位为秒),用于 repaint 任务。
|
||||
* `repainting_strength`: 重绘强度,控制重绘区域的变化程度,默认为 1.0。
|
||||
* `duration`: 音频时长(秒),默认为 60。
|
||||
* `bpm`: 每分钟节拍数,默认为 100。
|
||||
* `keyscale`: 音阶调式,默认为 "B minor"。
|
||||
* `timesignature`: 拍号,默认为 "4"。
|
||||
* `vocal_language`: 演唱语言,默认为 "unknown"。
|
||||
* `seed`: 随机种子。
|
||||
* `rand_device`: 噪声生成设备,默认为 "cpu"。
|
||||
* `num_inference_steps`: 推理步数,默认为 8。
|
||||
* `shift`: 调度器时间偏移参数,默认为 1.0。
|
||||
|
||||
## 模型训练
|
||||
|
||||
ace_step 系列模型统一通过 `examples/ace_step/model_training/train.py` 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||
* `--weight_decay`: 权重衰减大小。
|
||||
* `--task`: 训练任务,默认为 `sft`。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 分辨率配置
|
||||
* `--height`: 图像/视频的高度。留空启用动态分辨率。
|
||||
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
|
||||
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||
* `--num_frames`: 视频的帧数(仅视频生成模型)。
|
||||
* ACE-Step 专有参数
|
||||
* `--tokenizer_path`: Tokenizer 路径,格式为 model_id:origin_pattern。
|
||||
* `--silence_latent_path`: 静音隐变量路径,格式为 model_id:origin_pattern。
|
||||
* `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型。
|
||||
|
||||
### 样例数据集
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||
@@ -1,134 +0,0 @@
|
||||
# ERNIE-Image
|
||||
|
||||
ERNIE-Image 是百度推出的拥有 8B 参数的图像生成模型,具有紧凑高效的架构和出色的指令跟随能力。基于 8B DiT 主干网络,其在某些场景下的性能可与 20B 以上的更大模型相媲美,同时保持了良好的参数效率。该模型在指令理解与执行、文本生成(如英文/中文/日文)以及整体稳定性方面提供了较为可靠的表现。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 3G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
image.save("output.jpg")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[PaddlePaddle/ERNIE-Image](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/full/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_full/ERNIE-Image.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/lora/ERNIE-Image.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/validate_lora/ERNIE-Image.py)|
|
||||
|[PaddlePaddle/ERNIE-Image-Turbo](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-Image-Turbo)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference/ERNIE-Image-Turbo.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_inference_low_vram/ERNIE-Image-Turbo.py)|—|—|—|—|
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `ErnieImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`ErnieImagePipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述画面中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4.0。
|
||||
* `height`: 图像高度,需保证高度为 16 的倍数,默认值为 1024。
|
||||
* `width`: 图像宽度,需保证宽度为 16 的倍数,默认值为 1024。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cuda"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `num_inference_steps`: 推理步数,默认值为 50。
|
||||
|
||||
如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
ERNIE-Image 系列模型统一通过 [`examples/ernie_image/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ernie_image/model_training/train.py) 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"PaddlePaddle/ERNIE-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||
* `--fp8_models`:以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||
* `--weight_decay`:权重衰减大小。
|
||||
* `--task`: 训练任务,默认为 `sft`。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 分辨率配置
|
||||
* `--height`: 图像的高度。留空启用动态分辨率。
|
||||
* `--width`: 图像的宽度。留空启用动态分辨率。
|
||||
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||
* ERNIE-Image 专有参数
|
||||
* `--tokenizer_path`: tokenizer 的路径,留空则自动从远程下载。
|
||||
|
||||
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||
@@ -1,154 +0,0 @@
|
||||
# JoyAI-Image
|
||||
|
||||
JoyAI-Image 是京东开源的统一多模态基础模型,支持图像理解、文生图生成和指令引导的图像编辑。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](../Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 4G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.joyai_image import JoyAIImagePipeline, ModelConfig
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
# Download dataset
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="joyai_image/JoyAI-Image-Edit/*"
|
||||
)
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = JoyAIImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="transformer/transformer.pth", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="vae/Wan2.1_VAE.pth", **vram_config),
|
||||
],
|
||||
processor_config=ModelConfig(model_id="jd-opensource/JoyAI-Image-Edit", origin_file_pattern="JoyAI-Image-Und/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
# Use first sample from dataset
|
||||
dataset_base_path = "data/diffsynth_example_dataset/joyai_image/JoyAI-Image-Edit"
|
||||
prompt = "将裙子改为粉色"
|
||||
edit_image = Image.open(f"{dataset_base_path}/edit/image1.jpg").convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
edit_image=edit_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=0,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=5.0,
|
||||
)
|
||||
|
||||
output.save("output_joyai_edit_low_vram.png")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[jd-opensource/JoyAI-Image-Edit](https://modelscope.cn/models/jd-opensource/JoyAI-Image-Edit)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_inference_low_vram/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/full/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_full/JoyAI-Image-Edit.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/lora/JoyAI-Image-Edit.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/joyai_image/model_training/validate_lora/JoyAI-Image-Edit.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `JoyAIImagePipeline.from_pretrained` 加载,详见[加载模型](../Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`JoyAIImagePipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 文本提示词,用于描述期望的图像编辑效果。
|
||||
* `negative_prompt`: 负向提示词,指定不希望出现在结果中的内容,默认为空字符串。
|
||||
* `cfg_scale`: 分类器自由引导的缩放系数,默认为 5.0。值越大,生成结果越贴近 prompt 描述。
|
||||
* `edit_image`: 待编辑的单张图像。
|
||||
* `denoising_strength`: 降噪强度,控制输入图像被重绘的程度,默认为 1.0。
|
||||
* `height`: 输出图像的高度,默认为 1024。需能被 16 整除。
|
||||
* `width`: 输出图像的宽度,默认为 1024。需能被 16 整除。
|
||||
* `seed`: 随机种子,用于控制生成的可复现性。设为 `None` 时使用随机种子。
|
||||
* `max_sequence_length`: 文本编码器处理的最大序列长度,默认为 4096。
|
||||
* `num_inference_steps`: 推理步数,默认为 30。步数越多,生成质量通常越好。
|
||||
* `tiled`: 是否启用分块处理,用于降低显存占用,默认为 False。
|
||||
* `tile_size`: 分块大小,默认为 (30, 52)。
|
||||
* `tile_stride`: 分块步幅,默认为 (15, 26)。
|
||||
* `shift`: 调度器的 shift 参数,用于控制 Flow Match 的调度曲线,默认为 4.0。
|
||||
* `progress_bar_cmd`: 进度条显示方式,默认为 tqdm。
|
||||
|
||||
## 模型训练
|
||||
|
||||
joyai_image 系列模型统一通过 `examples/joyai_image/model_training/train.py` 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloader 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,以 `,` 分隔。
|
||||
* `--fp8_models`: 以 FP8 格式加载的模型,目前仅支持参数不被梯度更新的模型。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
|
||||
* `--weight_decay`: 权重衰减大小。
|
||||
* `--task`: 训练任务,默认为 `sft`。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 分辨率配置
|
||||
* `--height`: 图像/视频的高度。留空启用动态分辨率。
|
||||
* `--width`: 图像/视频的宽度。留空启用动态分辨率。
|
||||
* `--max_pixels`: 最大像素面积,动态分辨率时大于此值的图片会被缩小。
|
||||
* `--num_frames`: 视频的帧数(仅视频生成模型)。
|
||||
* JoyAI-Image 专有参数
|
||||
* `--processor_path`: Processor 路径,用于处理文本和图像的编码器输入。
|
||||
* `--initialize_model_on_cpu`: 是否在 CPU 上初始化模型,默认在加速设备上初始化。
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --local_dir ./data/diffsynth_example_dataset
|
||||
```
|
||||
|
||||
关于如何编写模型训练脚本,请参考[模型训练](../Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](https://github.com/modelscope/DiffSynth-Studio/tree/main/docs/zh/Training/)。
|
||||
@@ -1,84 +0,0 @@
|
||||
# 推理加速
|
||||
|
||||
扩散模型的去噪过程通常耗时较长。为提升推理速度,可采用多种加速技术,包含多卡并行推理、计算图编译等无损加速方案,以及 Cache、量化等有损加速方案。
|
||||
|
||||
当前扩散模型大多基于 Diffusion Transformer 构建,高效注意力机制同样是常用的加速手段。DiffSynth-Studio 目前已支持部分无损加速推理功能。本节重点从多卡并行推理和计算图编译两个维度介绍加速方法。
|
||||
|
||||
## 高效注意力机制
|
||||
注意力机制的加速细节请参考 [注意力机制实现](../API_Reference/core/attention.md)。
|
||||
|
||||
## 多卡并行推理
|
||||
DiffSynth-Studio 采用统一序列并行的多卡推理方案。在 DiT 中将 token 序列拆分至多张显卡进行并行处理。底层基于 [xDiT](https://github.com/xdit-project/xDiT) 实现。需要注意,统一序列并行会引入额外通信开销,实际加速比通常低于显卡数量。
|
||||
|
||||
目前 DiffSynth-Studio 已支持 [Wan](../Model_Details/Wan.md) 和 [MOVA](../Model_Details/Wan.md) 模型的统一序列并行加速。
|
||||
|
||||
首先安装 `xDiT` 依赖。
|
||||
```bash
|
||||
pip install "xfuser[flash-attn]>=0.4.3"
|
||||
```
|
||||
|
||||
然后使用 `torchrun` 启动多卡推理。
|
||||
```bash
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||
```
|
||||
|
||||
构建 pipeline 时配置 `usp=True` 即可实现 USP 并行推理。代码示例如下。
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
import torch.distributed as dist
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
use_usp=True,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
if dist.get_rank() == 0:
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
## 计算图编译
|
||||
PyTorch 2.0 提供了自动计算图编译接口 [torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html),能够将 PyTorch 代码即时编译为优化内核,从而提升运行速度。由于扩散模型的推理耗时集中在 DiT 的多步去噪阶段,且 DiT 主要由基础模块堆叠而成,为缩短编译时间,DiffSynth 的 compile 功能采用仅针对基础 Transformer 模块的 [区域编译](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) 策略。
|
||||
|
||||
### Compile 使用示例
|
||||
相比常规推理,只需在调用 pipeline 前执行 `pipe.compile_pipeline()` 即可开启编译加速。具体函数定义请参阅[源代码](https://github.com/modelscope/DiffSynth-Studio/blob/166e6d2d38764209f66a74dd0fe468226536ad0f/diffsynth/diffusion/base_pipeline.py#L342)。
|
||||
|
||||
`compile_pipeline` 的输入参数主要包含两类。
|
||||
|
||||
第一类是编译模型参数 `compile_models`。以 Qwen-Image Pipeline 为例,若仅编译 DiT 模型,保持该参数为空即可。若需额外编译 VAE 等模型,可传入 `compile_models=["vae", "dit"]`。除 DiT 外,其余模型均采用整体编译策略,即把模型的 forward 函数完整编译为计算图。
|
||||
|
||||
第二类是编译策略参数。涵盖 `mode`, `dynamic`, `fullgraph` 及其他自定义选项。这些参数会直接传递给 `torch.compile` 接口。若未深入了解这些参数的具体机制,建议保持默认设置。
|
||||
|
||||
- `mode` 指定编译模式,包含 `"default"`, `"reduce-overhead"`, `"max-autotune"` 和 `"max-autotune-no-cudagraphs"`。由于 cudagraph 对计算图要求较为严格(例如可能需要配合 `torch.compiler.cudagraph_mark_step_begin()` 使用),`"reduce-overhead"` 和 `"max-autotune"` 模式可能编译失败。
|
||||
- `dynamic` 决定是否启用动态形状。对于多数生成模型,修改 prompt、开启 CFG 或调整分辨率都会改变计算图的输入张量形状。设置为 `dynamic=True` 会增加首次运行的编译时长,但支持动态形状,形状改变时无需重编译。设置为 `dynamic=False` 时首次编译较快,但任何改变输入形状的操作都会触发重新编译。对大部分场景,建议设定为 `dynamic=True`。
|
||||
- `fullgraph` 设为 `True` 时,底层会尝试将目标模型编译为单一计算图,若失败则报错。设为 `False` 时,底层会在无法连接处设置断点,将模型编译为多个独立计算图。开发者可开启 `True` 来优化编译性能,普通用户建议仅使用 `False`。
|
||||
- 其他参数配置请查阅 [api 文档](https://docs.pytorch.org/docs/stable/generated/torch.compile.html)。
|
||||
|
||||
### Compile 功能开发者文档
|
||||
若需为新接入的 pipeline 提供 compile 支持,应在 pipeline 中配置 `compilable_models` 属性以指定默认编译模型。针对该 pipeline 的 DiT 模型类,还需配置 `_repeated_blocks` 以指定参与区域编译的基础模块类型。
|
||||
|
||||
以 Qwen-Image 为例,其 pipeline 配置如下。
|
||||
```python
|
||||
self.compilable_models = ["dit"]
|
||||
```
|
||||
|
||||
其 DiT 配置如下。
|
||||
```python
|
||||
class QwenImageDiT(torch.nn.Module):
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
```
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
Pipeline_Usage/Setup
|
||||
Pipeline_Usage/Model_Inference
|
||||
Pipeline_Usage/Accelerated_Inference
|
||||
Pipeline_Usage/VRAM_management
|
||||
Pipeline_Usage/Model_Training
|
||||
Pipeline_Usage/Environment_Variables
|
||||
@@ -30,9 +29,6 @@
|
||||
Model_Details/Z-Image
|
||||
Model_Details/Anima
|
||||
Model_Details/LTX-2
|
||||
Model_Details/ERNIE-Image
|
||||
Model_Details/JoyAI-Image
|
||||
Model_Details/ACE-Step
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from modelscope import dataset_snapshot_download
|
||||
import torch
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo.wav")
|
||||
|
||||
# input audio codes as reference
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
|
||||
)
|
||||
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
|
||||
audio_code_string = f.read().strip()
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
task_type="cover",
|
||||
audio_code_string=audio_code_string,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes.wav")
|
||||
@@ -1,45 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||
from modelscope import dataset_snapshot_download
|
||||
import torch
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
|
||||
)
|
||||
|
||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
|
||||
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
task_type="cover",
|
||||
src_audio=src_audio,
|
||||
audio_cover_strength=0.5,
|
||||
denoising_strength=0.9,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")
|
||||
@@ -1,47 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||
from modelscope import dataset_snapshot_download
|
||||
import torch
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
|
||||
)
|
||||
|
||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
|
||||
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
|
||||
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
task_type="repaint",
|
||||
src_audio=src_audio,
|
||||
repainting_ranges=[(-10, 30), (150, 200)],
|
||||
repainting_strength=1.0,
|
||||
duration=210,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")
|
||||
@@ -1,31 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base.wav")
|
||||
@@ -1,38 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example.
|
||||
|
||||
SFT variant is fine-tuned for specific music styles.
|
||||
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft.wav")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example.
|
||||
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
Continuous variant: handles shift range internally, no shift parameter needed.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous.wav")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example.
|
||||
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
shift=1: default value, no need to pass.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1.wav")
|
||||
@@ -1,37 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example.
|
||||
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
shift=3: explicitly passed for this variant.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
shift=3,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3.wav")
|
||||
@@ -1,38 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 XL Base (32 layers, hidden_size=2560) — Text-to-Music inference example.
|
||||
|
||||
XL variant with larger capacity for higher quality generation.
|
||||
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base.wav")
|
||||
@@ -1,37 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 XL SFT (32 layers, supervised fine-tuned) — Text-to-Music inference example.
|
||||
|
||||
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft.wav")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 XL Turbo (32 layers, fast generation) — Text-to-Music inference example.
|
||||
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
shift=3: explicitly passed for this variant.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo.wav")
|
||||
@@ -1,73 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 (main model, turbo) — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
Turbo model: uses num_inference_steps=8, cfg_scale=1.0.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from modelscope import dataset_snapshot_download
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-low-vram.wav")
|
||||
|
||||
# input audio codes as reference
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="ace_step/Ace-Step1.5/audio_codes_input.txt",
|
||||
)
|
||||
with open("data/diffsynth_example_dataset/ace_step/Ace-Step1.5/audio_codes_input.txt", "r") as f:
|
||||
audio_code_string = f.read().strip()
|
||||
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
task_type="cover",
|
||||
audio_code_string=audio_code_string,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo5-with-audio-codes-low-vram.wav")
|
||||
@@ -1,57 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||
from modelscope import dataset_snapshot_download
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="ace_step/acestep-v15-base-CoverTask/audio.wav",
|
||||
)
|
||||
|
||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-CoverTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||
# audio_cover_strength controls the steps of doing cover tasks. [0, num_inference_steps * audio_cover_strength] steps will be cover steps, and the rest will be regular text-to-music generation steps.
|
||||
# denoising_strength controls how the output audio is influenced by the source audio in cover tasks.
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
task_type="cover",
|
||||
src_audio=src_audio,
|
||||
audio_cover_strength=0.5,
|
||||
denoising_strength=0.9,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-cover.wav")
|
||||
@@ -1,59 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio, read_audio
|
||||
from modelscope import dataset_snapshot_download
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/diffsynth_example_dataset",
|
||||
local_dir="data/diffsynth_example_dataset",
|
||||
allow_file_pattern="ace_step/acestep-v15-base-RepaintTask/audio.wav",
|
||||
)
|
||||
|
||||
src_audio, sr = read_audio("data/diffsynth_example_dataset/ace_step/acestep-v15-base-RepaintTask/audio.wav", resample=True, resample_rate=pipe.vae.sampling_rate)
|
||||
# repainting_ranges are in seconds, and will be converted to frames internally in the pipeline. The negative value in repainting_ranges means the padding from the start of the audio.
|
||||
# For example, repainting_ranges=[(-10, 30), (160, 200)] means we want to repaint the audio from -10s to 30s (with 10s padding before the start) and from 160s to 200s. The non-existent parts will be padded with silence.
|
||||
# Repainting strength denotes the intensity of repainting area, where 0 means no repainting (keep the original audio) and 1 means full repainting.
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
task_type="repaint",
|
||||
src_audio=src_audio,
|
||||
repainting_ranges=[(-10, 30), (150, 200)],
|
||||
repainting_strength=1.0,
|
||||
duration=210,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-repaint.wav")
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 Base — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base-low-vram.wav")
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
SFT variant is fine-tuned for specific music styles.
|
||||
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft-low-vram.wav")
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 Turbo (continuous, shift 1-5) — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
Continuous variant: handles shift range internally, no shift parameter needed.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous-low-vram.wav")
|
||||
@@ -1,49 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 Turbo (shift=1) — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
shift=1: default value, no need to pass.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1-low-vram.wav")
|
||||
@@ -1,50 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 Turbo (shift=3) — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
shift=3: explicitly passed for this variant.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
shift=3,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3-low-vram.wav")
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 XL Base — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
torch.cuda.reset_peak_memory_stats("cuda")
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base-low-vram.wav")
|
||||
@@ -1,50 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 XL SFT (supervised fine-tuned) — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
Non-turbo model: uses num_inference_steps=30, cfg_scale=4.0.
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft-low-vram.wav")
|
||||
@@ -1,48 +0,0 @@
|
||||
"""
|
||||
Ace-Step 1.5 XL Turbo — Text-to-Music inference example (Low VRAM).
|
||||
|
||||
Low VRAM version: models are offloaded to CPU and loaded on-demand.
|
||||
Turbo model: no num_inference_steps or cfg_scale (use defaults).
|
||||
"""
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo-low-vram.wav")
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/Ace-Step1.5_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-base_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-sft_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-turbo-continuous_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-turbo-shift1_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-turbo-shift3_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-xl-base_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-xl-sft_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,18 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-xl-turbo_full" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/Ace-Step1.5/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/Ace-Step1.5/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/Ace-Step1.5/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/Ace-Step1.5:acestep-v15-turbo/model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/Ace-Step1.5_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-base/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-base/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-base/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-base:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-base_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-sft/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-sft/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-sft/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-sft:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-sft_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-continuous/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-continuous/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-continuous:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-turbo-continuous_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift1/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift1/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift1:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-turbo-shift1_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-turbo-shift3/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-turbo-shift3/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-turbo-shift3:model.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-turbo-shift3_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-base/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-base/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-base:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-xl-base_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-sft/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-sft/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-sft:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-xl-sft_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,20 +0,0 @@
|
||||
# Dataset: data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "ace_step/acestep-v15-xl-turbo/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
accelerate launch examples/ace_step/model_training/train.py \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 20 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/ace_step/acestep-v15-xl-turbo/metadata.json" \
|
||||
--model_id_with_origin_paths "ACE-Step/acestep-v15-xl-turbo:model-*.safetensors,ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/model.safetensors,ACE-Step/Ace-Step1.5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "ACE-Step/Ace-Step1.5:Qwen3-Embedding-0.6B/" \
|
||||
--silence_latent_path "ACE-Step/Ace-Step1.5:acestep-v15-turbo/silence_latent.pt" \
|
||||
--lora_base_model "dit" \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/acestep-v15-xl-turbo_lora" \
|
||||
--lora_target_modules "q_proj,k_proj,v_proj,o_proj" \
|
||||
--data_file_keys "audio"
|
||||
@@ -1,144 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
import math
|
||||
import argparse
|
||||
import accelerate
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.core.data.operators import ToAbsolutePath, LoadPureAudioWithTorchaudio
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.diffusion import *
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class AceStepTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
tokenizer_path=None, silence_latent_path=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
preset_lora_path=None, preset_lora_model=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
fp8_models=None,
|
||||
offload_models=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
):
|
||||
super().__init__()
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
text_tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"))
|
||||
silence_latent_config = self.parse_path_or_model_id(silence_latent_path, default_value=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"))
|
||||
self.pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16, device=device, model_configs=model_configs,
|
||||
text_tokenizer_config=text_tokenizer_config, silence_latent_config=silence_latent_config,
|
||||
)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||
preset_lora_path, preset_lora_model,
|
||||
task=task,
|
||||
)
|
||||
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.fp8_models = fp8_models
|
||||
self.task = task
|
||||
self.task_to_loss = {
|
||||
"sft:data_process": lambda pipe, *args: args,
|
||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
}
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"], "positive": True}
|
||||
inputs_nega = {"positive": False}
|
||||
duration = math.floor(data['audio'][0].shape[1] / data['audio'][1]) if data.get("audio") is not None else data.get("duration", 60)
|
||||
inputs_shared = {
|
||||
"input_audio": data["audio"],
|
||||
"lyrics": data["lyrics"],
|
||||
"task_type": "text2music",
|
||||
"duration": duration,
|
||||
"bpm": data.get("bpm", 100),
|
||||
"keyscale": data.get("keyscale", "C major"),
|
||||
"timesignature": data.get("timesignature", "4"),
|
||||
"vocal_language": data.get("vocal_language", "unknown"),
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
for unit in self.pipe.units:
|
||||
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||
return loss
|
||||
|
||||
|
||||
def ace_step_parser():
|
||||
parser = argparse.ArgumentParser(description="ACE-Step training.")
|
||||
parser = add_general_config(parser)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Tokenizer path in format model_id:origin_pattern.")
|
||||
parser.add_argument("--silence_latent_path", type=str, default=None, help="Silence latent path in format model_id:origin_pattern.")
|
||||
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ace_step_parser()
|
||||
args = parser.parse_args()
|
||||
accelerator = accelerate.Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=None,
|
||||
special_operator_map={
|
||||
"audio": ToAbsolutePath(args.dataset_base_path) >> LoadPureAudioWithTorchaudio(
|
||||
target_sample_rate=48000,
|
||||
),
|
||||
},
|
||||
)
|
||||
model = AceStepTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
silence_latent_path=args.silence_latent_path,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
preset_lora_path=args.preset_lora_path,
|
||||
preset_lora_model=args.preset_lora_model,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
fp8_models=args.fp8_models,
|
||||
offload_models=args.offload_models,
|
||||
task=args.task,
|
||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||
)
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
)
|
||||
launcher_map = {
|
||||
"sft:data_process": launch_data_process_task,
|
||||
"sft": launch_training_task,
|
||||
"sft:train": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Ace-Step1.5_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-base_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=42,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-sft_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-turbo-continuous_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-turbo-shift1_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-turbo-shift3_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-xl-base_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-xl-sft_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_full.wav")
|
||||
@@ -1,35 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
from diffsynth import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
state_dict = load_state_dict("models/train/acestep-v15-xl-turbo_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=8,
|
||||
cfg_scale=1.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_full.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Ace-Step1.5_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "Ace-Step1.5_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-base", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-base_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-base_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-sft", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-sft_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-sft_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-continuous", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-continuous_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-continuous_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift1", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift1_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift1_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-turbo-shift3", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-turbo-shift3_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-turbo-shift3_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-base", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-base_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-base_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-sft", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-sft_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-sft_lora.wav")
|
||||
@@ -1,33 +0,0 @@
|
||||
from diffsynth.pipelines.ace_step import AceStepPipeline, ModelConfig
|
||||
from diffsynth.utils.data.audio import save_audio
|
||||
import torch
|
||||
|
||||
|
||||
pipe = AceStepPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="ACE-Step/acestep-v15-xl-turbo", origin_file_pattern="model-*.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
|
||||
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
text_tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/"),
|
||||
silence_latent_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/acestep-v15-xl-turbo_lora/epoch-9.safetensors", alpha=1)
|
||||
|
||||
prompt = "An explosive, high-energy pop-rock track with a strong anime theme song feel. The song kicks off with a catchy, synthesized brass fanfare over a driving rock beat with punchy drums and a solid bassline. A powerful, clear male vocal enters with a theatrical and energetic delivery, soaring through the verses and hitting powerful high notes in the chorus. The arrangement is dense and dynamic, featuring rhythmic electric guitar chords, brief instrumental breaks with synth flourishes, and a consistent, danceable groove throughout. The overall mood is triumphant, adventurous, and exhilarating."
|
||||
lyrics = '[Intro - Synth Brass Fanfare]\n\n[Verse 1]\n黑夜里的风吹过耳畔\n甜蜜时光转瞬即万\n脚步飘摇在星光上\n心追节奏心跳狂乱\n耳边传来电吉他呼唤\n手指轻触碰点流点燃\n梦在云端任它蔓延\n疯狂跳跃自由无间\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Instrumental Break - Synth Brass Melody]\n\n[Verse 2]\n鼓点撞击黑夜的底端\n跳动节拍连接你我俩\n在这里让灵魂发光\n燃尽所有不留遗憾\n\n[Instrumental Break - Synth Brass Melody]\n\n[Bridge]\n光影交错彼此的视线\n霓虹之下夜空的蔚蓝\n月光洒下温热心田\n追逐梦想它不会遥远\n\n[Chorus]\n心电感应在震动间\n拥抱未来勇敢冒险\n那旋律在心中无限\n世界变得如此耀眼\n\n[Outro - Instrumental with Synth Brass Melody]\n[Song ends abruptly]'
|
||||
audio = pipe(
|
||||
prompt=prompt,
|
||||
lyrics=lyrics,
|
||||
duration=160,
|
||||
bpm=100,
|
||||
keyscale="B minor",
|
||||
timesignature="4",
|
||||
vocal_language="zh",
|
||||
seed=1,
|
||||
num_inference_steps=30,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
save_audio(audio, pipe.vae.sampling_rate, "acestep-v15-xl-turbo_lora.wav")
|
||||
@@ -1,25 +0,0 @@
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image-Turbo", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=8,
|
||||
cfg_scale=1.0,
|
||||
sigma_shift=4.0,
|
||||
)
|
||||
image.save("output_turbo.jpg")
|
||||
@@ -1,24 +0,0 @@
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
image.save("output.jpg")
|
||||
@@ -1,37 +0,0 @@
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image-Turbo", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=8,
|
||||
cfg_scale=1.0,
|
||||
sigma_shift=4.0,
|
||||
)
|
||||
image.save("output_turbo.jpg")
|
||||
@@ -1,36 +0,0 @@
|
||||
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
|
||||
pipe = ErnieImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device='cuda',
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="PaddlePaddle/ERNIE-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt="一只黑白相间的中华田园犬",
|
||||
negative_prompt="",
|
||||
height=1024,
|
||||
width=1024,
|
||||
seed=42,
|
||||
num_inference_steps=50,
|
||||
cfg_scale=4.0,
|
||||
)
|
||||
image.save("output.jpg")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user