Compare commits

..

1 Commits

Author SHA1 Message Date
Artiprocher
59f512b574 add acestep models 2026-04-02 10:58:45 +08:00
12 changed files with 3177 additions and 292 deletions

View File

@@ -884,4 +884,40 @@ mova_series = [
"model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge", "model_class": "diffsynth.models.mova_dual_tower_bridge.DualTowerConditionalBridge",
}, },
] ]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series
ace_step_series = [
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors")
"model_hash": "3509bea17b0e8cffc3dd4a15cc7899d0",
"model_name": "ace_step_text_encoder",
"model_class": "diffsynth.models.ace_step_text_encoder.AceStepTextEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_text_encoder.AceStepTextEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
"model_hash": "51420834e54474986a7f4be0e4d6f687",
"model_name": "ace_step_vae",
"model_class": "diffsynth.models.ace_step_vae.AceStepVAE",
"extra_kwargs": {
"encoder_hidden_size": 128,
"downsampling_ratios": [2, 4, 4, 6, 10],
"channel_multiples": [1, 2, 4, 8, 16],
"decoder_channels": 128,
"decoder_input_channels": 64,
"audio_channels": 2,
"sampling_rate": 48000
}
},
{
# Example: ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors")
"model_hash": "ba29d8bddbb6ace65675f6a757a13c00",
"model_name": "ace_step_dit",
"model_class": "diffsynth.models.ace_step_dit.AceStepConditionGenerationModelWrapper",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ace_step_dit.AceStepDiTStateDictConverter",
"extra_kwargs": {
"config_path": "models/ACE-Step/Ace-Step1.5/acestep-v15-turbo"
}
},
]
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series + anima_series + mova_series + ace_step_series

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,38 @@
from transformers import Qwen3Model, Qwen3Config
import torch
class AceStepTextEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
config = Qwen3Config(**{
"architectures": ["Qwen3Model"],
"attention_bias": False,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 32768,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 1000000,
"sliding_window": None,
"tie_word_embeddings": True,
"torch_dtype": "bfloat16",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151669
})
self.model = Qwen3Model(config)
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)

View File

@@ -0,0 +1,416 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class Snake1d(nn.Module):
"""
A 1-dimensional Snake activation function module.
"""
def __init__(self, hidden_dim, logscale=True):
super().__init__()
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.alpha.requires_grad = True
self.beta.requires_grad = True
self.logscale = logscale
def forward(self, hidden_states):
shape = hidden_states.shape
alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
beta = self.beta if not self.logscale else torch.exp(self.beta)
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
hidden_states = hidden_states.reshape(shape)
return hidden_states
class OobleckResidualUnit(nn.Module):
"""
A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
"""
def __init__(self, dimension: int = 16, dilation: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
self.snake2 = Snake1d(dimension)
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
def forward(self, hidden_state):
output_tensor = hidden_state
output_tensor = self.conv1(self.snake1(output_tensor))
output_tensor = self.conv2(self.snake2(output_tensor))
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
output_tensor = hidden_state + output_tensor
return output_tensor
class OobleckEncoderBlock(nn.Module):
"""Encoder block used in Oobleck encoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
self.snake1 = Snake1d(input_dim)
self.conv1 = weight_norm(
nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
)
def forward(self, hidden_state):
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.snake1(self.res_unit3(hidden_state))
hidden_state = self.conv1(hidden_state)
return hidden_state
class OobleckDecoderBlock(nn.Module):
"""Decoder block used in Oobleck decoder."""
def __init__(self, input_dim, output_dim, stride: int = 1):
super().__init__()
self.snake1 = Snake1d(input_dim)
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
)
)
self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
def forward(self, hidden_state):
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state
class OobleckDiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.scale = parameters.chunk(2, dim=1)
self.std = nn.functional.softplus(self.scale) + 1e-4
self.var = self.std * self.std
self.logvar = torch.log(self.var)
self.deterministic = deterministic
def sample(self, generator: torch.Generator = None) -> torch.Tensor:
device = self.parameters.device
dtype = self.parameters.dtype
sample = torch.randn(self.mean.shape, generator=generator, device=device, dtype=dtype)
x = self.mean + self.std * sample
return x
def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
else:
normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
var_ratio = self.var / other.var
logvar_diff = self.logvar - other.logvar
kl = normalized_diff + var_ratio + logvar_diff - 1
kl = kl.sum(1).mean()
return kl
def mode(self) -> torch.Tensor:
return self.mean
@dataclass
class AutoencoderOobleckOutput:
"""
Output of AutoencoderOobleck encoding method.
Args:
latent_dist (`OobleckDiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and standard deviation of
`OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents
from the distribution.
"""
latent_dist: "OobleckDiagonalGaussianDistribution"
@dataclass
class OobleckDecoderOutput:
r"""
Output of decoding method.
Args:
sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
The decoded output sample from the last layer of the model.
"""
sample: torch.Tensor
class OobleckEncoder(nn.Module):
"""Oobleck Encoder"""
def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
super().__init__()
strides = downsampling_ratios
channel_multiples = [1] + channel_multiples
# Create first convolution
self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
self.block = []
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride_index, stride in enumerate(strides):
self.block += [
OobleckEncoderBlock(
input_dim=encoder_hidden_size * channel_multiples[stride_index],
output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
stride=stride,
)
]
self.block = nn.ModuleList(self.block)
d_model = encoder_hidden_size * channel_multiples[-1]
self.snake1 = Snake1d(d_model)
self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for module in self.block:
hidden_state = module(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class OobleckDecoder(nn.Module):
"""Oobleck Decoder"""
def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
super().__init__()
strides = upsampling_ratios
channel_multiples = [1] + channel_multiples
# Add first conv layer
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
# Add upsampling + MRF blocks
block = []
for stride_index, stride in enumerate(strides):
block += [
OobleckDecoderBlock(
input_dim=channels * channel_multiples[len(strides) - stride_index],
output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
stride=stride,
)
]
self.block = nn.ModuleList(block)
output_dim = channels
self.snake1 = Snake1d(output_dim)
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv2(hidden_state)
return hidden_state
class AceStepVAE(nn.Module):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
Parameters:
encoder_hidden_size (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the encoder.
downsampling_ratios (`list[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
channel_multiples (`list[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
Multiples used to determine the hidden sizes of the hidden layers.
decoder_channels (`int`, *optional*, defaults to 128):
Intermediate representation dimension for the decoder.
decoder_input_channels (`int`, *optional*, defaults to 64):
Input dimension for the decoder. Corresponds to the latent dimension.
audio_channels (`int`, *optional*, defaults to 2):
Number of channels in the audio data. Either 1 for mono or 2 for stereo.
sampling_rate (`int`, *optional*, defaults to 44100):
The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
"""
def __init__(
self,
encoder_hidden_size=128,
downsampling_ratios=[2, 4, 4, 8, 8],
channel_multiples=[1, 2, 4, 8, 16],
decoder_channels=128,
decoder_input_channels=64,
audio_channels=2,
sampling_rate=44100,
):
super().__init__()
self.encoder_hidden_size = encoder_hidden_size
self.downsampling_ratios = downsampling_ratios
self.decoder_channels = decoder_channels
self.upsampling_ratios = downsampling_ratios[::-1]
self.hop_length = int(np.prod(downsampling_ratios))
self.sampling_rate = sampling_rate
self.encoder = OobleckEncoder(
encoder_hidden_size=encoder_hidden_size,
audio_channels=audio_channels,
downsampling_ratios=downsampling_ratios,
channel_multiples=channel_multiples,
)
self.decoder = OobleckDecoder(
channels=decoder_channels,
input_channels=decoder_input_channels,
audio_channels=audio_channels,
upsampling_ratios=self.upsampling_ratios,
channel_multiples=channel_multiples,
)
self.use_slicing = False
def encode(self, x: torch.Tensor, return_dict: bool = True):
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
posterior = OobleckDiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderOobleckOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True):
dec = self.decoder(z)
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)
def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None):
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.OobleckDecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return OobleckDecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: torch.Generator = None,
):
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return OobleckDecoderOutput(sample=dec)

View File

@@ -1,4 +1,4 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessor from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch import torch
@@ -40,7 +40,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
value_bias = False value_bias = False
) )
super().__init__(config) super().__init__(config)
self.processor = DINOv3ViTImageProcessor( self.processor = DINOv3ViTImageProcessorFast(
crop_size = None, crop_size = None,
data_format = "channels_first", data_format = "channels_first",
default_to_square = True, default_to_square = True,
@@ -56,7 +56,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
0.456, 0.456,
0.406 0.406
], ],
image_processor_type = "DINOv3ViTImageProcessor", image_processor_type = "DINOv3ViTImageProcessorFast",
image_std = [ image_std = [
0.229, 0.229,
0.224, 0.224,

View File

@@ -1,5 +1,5 @@
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessor from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
import torch import torch
from diffsynth.core.device.npu_compatible_device import get_device_type from diffsynth.core.device.npu_compatible_device import get_device_type
@@ -90,7 +90,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
transformers_version = "4.57.1" transformers_version = "4.57.1"
) )
super().__init__(config) super().__init__(config)
self.processor = Siglip2ImageProcessor( self.processor = Siglip2ImageProcessorFast(
**{ **{
"data_format": "channels_first", "data_format": "channels_first",
"default_to_square": True, "default_to_square": True,
@@ -106,7 +106,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
0.5, 0.5,
0.5 0.5
], ],
"image_processor_type": "Siglip2ImageProcessor", "image_processor_type": "Siglip2ImageProcessorFast",
"image_std": [ "image_std": [
0.5, 0.5,
0.5, 0.5,

View File

@@ -0,0 +1,217 @@
import torch, math
from PIL import Image
from typing import Union
from tqdm import tqdm
from einops import rearrange
import numpy as np
from math import prod
from transformers import AutoTokenizer
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
from ..utils.lora.merge import merge_lora
from ..core.device.npu_compatible_device import get_device_type
from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline
from ..models.ace_step_text_encoder import AceStepTextEncoder
from ..models.ace_step_vae import AceStepVAE
from ..models.ace_step_dit import AceStepConditionGenerationModelWrapper
class AceStepAudioPipeline(BasePipeline):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.text_encoder: AceStepTextEncoder = None
self.dit: AceStepConditionGenerationModelWrapper = None
self.vae: AceStepVAE = None
self.scheduler = FlowMatchScheduler()
self.tokenizer: AutoTokenizer = None
self.in_iteration_models = ("dit",)
self.units = []
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B"),
vram_limit: float = None,
):
# Initialize pipeline
pipe = AceStepAudioPipeline(device=device, torch_dtype=torch_dtype)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
# Fetch models
pipe.text_encoder = model_pool.fetch_model("ace_step_text_encoder")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
if tokenizer_config is not None:
tokenizer_config.download_if_necessary()
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
@torch.no_grad()
def __call__(
self,
caption: str,
lyrics: str = "",
duration: float = 160,
bpm: int = None,
keyscale: str = "",
timesignature: str = "",
vocal_language: str = "zh",
instrumental: bool = False,
inference_steps: int = 8,
guidance_scale: float = 3.0,
seed: int = None,
):
# Format text prompt with metadata
text_prompt = self._format_text_prompt(caption, bpm, keyscale, timesignature, duration)
lyrics_text = self._format_lyrics(lyrics, vocal_language, instrumental)
# Tokenize
text_inputs = self.tokenizer(
text_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512,
).to(self.device)
lyrics_inputs = self.tokenizer(
lyrics_text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048,
).to(self.device)
# Encode text and lyrics
text_outputs = self.text_encoder(
input_ids=text_inputs["input_ids"],
attention_mask=text_inputs["attention_mask"],
)
lyrics_outputs = self.text_encoder(
input_ids=lyrics_inputs["input_ids"],
attention_mask=lyrics_inputs["attention_mask"],
)
# Get hidden states
text_hidden_states = text_outputs.last_hidden_state
lyric_hidden_states = lyrics_outputs.last_hidden_state
# Prepare generation parameters
latent_frames = int(duration * 46.875) # 48000 / 1024 ≈ 46.875 Hz
# For text2music task, use silence_latent as src_latents
# silence_latent will be tokenized/detokenized to get lm_hints_25Hz (127 dims)
# which will be used as context for generation
if self.silence_latent is not None:
# Slice or pad silence_latent to match latent_frames
if self.silence_latent.shape[1] >= latent_frames:
src_latents = self.silence_latent[:, :latent_frames, :].to(device=self.device, dtype=self.torch_dtype)
else:
# Pad with zeros if silence_latent is shorter
pad_len = latent_frames - self.silence_latent.shape[1]
src_latents = torch.cat([
self.silence_latent.to(device=self.device, dtype=self.torch_dtype),
torch.zeros(1, pad_len, self.src_latent_channels, device=self.device, dtype=self.torch_dtype)
], dim=1)
else:
# Fallback: create random latents if silence_latent is not loaded
src_latents = torch.randn(1, latent_frames, self.src_latent_channels,
device=self.device, dtype=self.torch_dtype)
# Create attention mask
attention_mask = torch.ones(1, latent_frames, device=self.device, dtype=self.torch_dtype)
# Use silence_latent for the silence_latent parameter as well
silence_latent = src_latents
# Chunk masks and is_covers (for text2music, these are all zeros)
# chunk_masks shape: [batch, latent_frames, 1]
chunk_masks = torch.zeros(1, latent_frames, 1, device=self.device, dtype=self.torch_dtype)
is_covers = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
# Reference audio (empty for text2music)
# For text2music mode, we need empty reference audio
# refer_audio_acoustic_hidden_states_packed: [batch, num_segments, hidden_dim]
# refer_audio_order_mask: [num_segments] - indicates which batch each segment belongs to
refer_audio_acoustic_hidden_states_packed = torch.zeros(1, 1, 64, device=self.device, dtype=self.torch_dtype)
refer_audio_order_mask = torch.zeros(1, device=self.device, dtype=torch.long) # 1-d tensor
# Generate audio latents using DiT model
generation_result = self.dit.model.generate_audio(
text_hidden_states=text_hidden_states,
text_attention_mask=text_inputs["attention_mask"],
lyric_hidden_states=lyric_hidden_states,
lyric_attention_mask=lyrics_inputs["attention_mask"],
refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed,
refer_audio_order_mask=refer_audio_order_mask,
src_latents=src_latents,
chunk_masks=chunk_masks,
is_covers=is_covers,
silence_latent=silence_latent,
attention_mask=attention_mask,
seed=seed if seed is not None else 42,
fix_nfe=inference_steps,
shift=guidance_scale,
)
# Extract target latents from result dictionary
generated_latents = generation_result["target_latents"]
# Decode latents to audio
# generated_latents shape: [batch, latent_frames, 64]
# VAE expects: [batch, latent_frames, 64]
audio_output = self.vae.decode(generated_latents, return_dict=True)
audio = audio_output.sample
# Post-process audio
audio = self._postprocess_audio(audio)
self.load_models_to_device([])
return audio
def _format_text_prompt(self, caption, bpm, keyscale, timesignature, duration):
"""Format text prompt with metadata"""
prompt = "# Instruction\nFill the audio semantic mask based on the given conditions:\n\n"
prompt += f"# Caption\n{caption}\n\n"
prompt += "# Metas\n"
if bpm:
prompt += f"- bpm: {bpm}\n"
if timesignature:
prompt += f"- timesignature: {timesignature}\n"
if keyscale:
prompt += f"- keyscale: {keyscale}\n"
prompt += f"- duration: {int(duration)} seconds\n"
prompt += "<|endoftext|>"
return prompt
def _format_lyrics(self, lyrics, vocal_language, instrumental):
"""Format lyrics with language"""
if instrumental or not lyrics:
lyrics = "[Instrumental]"
lyrics_text = f"# Languages\n{vocal_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
return lyrics_text
def _postprocess_audio(self, audio):
"""Post-process audio tensor"""
# Ensure audio is on CPU and in float32
audio = audio.to(device="cpu", dtype=torch.float32)
# Normalize to [-1, 1]
max_val = torch.abs(audio).max()
if max_val > 0:
audio = audio / max_val
return audio

View File

@@ -95,7 +95,7 @@ class ZImagePipeline(BasePipeline):
def __call__( def __call__(
self, self,
# Prompt # Prompt
prompt: str = "", prompt: str,
negative_prompt: str = "", negative_prompt: str = "",
cfg_scale: float = 1.0, cfg_scale: float = 1.0,
# Image # Image
@@ -109,7 +109,7 @@ class ZImagePipeline(BasePipeline):
width: int = 1024, width: int = 1024,
# Randomness # Randomness
seed: int = None, seed: int = None,
rand_device: Union[str, torch.device] = "cpu", rand_device: str = "cpu",
# Steps # Steps
num_inference_steps: int = 8, num_inference_steps: int = 8,
sigma_shift: float = None, sigma_shift: float = None,

View File

@@ -0,0 +1,15 @@
def AceStepDiTStateDictConverter(state_dict):
"""
Convert ACE-Step DiT state dict to add 'model.' prefix for wrapper class.
The wrapper class has self.model = AceStepConditionGenerationModel(config),
so all keys need to be prefixed with 'model.'
"""
state_dict_ = {}
keys = state_dict.keys() if hasattr(state_dict, 'keys') else state_dict
for k in keys:
v = state_dict[k]
if not k.startswith("model."):
k = "model." + k
state_dict_[k] = v
return state_dict_

View File

@@ -0,0 +1,19 @@
def AceStepTextEncoderStateDictConverter(state_dict):
"""
将 ACE-Step Text Encoder 权重添加 model. 前缀
Args:
state_dict: 原始的 state dict可能是 dict 或 DiskMap
Returns:
转换后的 state dict所有 key 添加 "model." 前缀
"""
state_dict_ = {}
# 处理 DiskMap 或普通 dict
keys = state_dict.keys() if hasattr(state_dict, 'keys') else state_dict
for k in keys:
v = state_dict[k]
if not k.startswith("model."):
k = "model." + k
state_dict_[k] = v
return state_dict_

View File

@@ -0,0 +1,14 @@
from diffsynth.pipelines.ace_step_audio import AceStepAudioPipeline, ModelConfig
import torch
pipe = AceStepAudioPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/model.safetensors"),
ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="Qwen3-Embedding-0.6B"),
)

View File

@@ -1,283 +0,0 @@
import importlib, inspect, pkgutil, traceback, torch, os, re
from typing import Union, List, Optional, Tuple, Iterable, Dict
from contextlib import contextmanager
import streamlit as st
from diffsynth import ModelConfig
from diffsynth.diffusion.base_pipeline import ControlNetInput
from PIL import Image
from tqdm import tqdm
st.set_page_config(layout="wide")
class StreamlitTqdmWrapper:
"""Wrapper class that combines tqdm and streamlit progress bar"""
def __init__(self, iterable, st_progress_bar=None):
self.iterable = iterable
self.st_progress_bar = st_progress_bar
self.tqdm_bar = tqdm(iterable)
self.total = len(iterable) if hasattr(iterable, '__len__') else None
self.current = 0
def __iter__(self):
for item in self.tqdm_bar:
if self.st_progress_bar is not None and self.total is not None:
self.current += 1
self.st_progress_bar.progress(self.current / self.total)
yield item
def __enter__(self):
return self
def __exit__(self, *args):
if hasattr(self.tqdm_bar, '__exit__'):
self.tqdm_bar.__exit__(*args)
@contextmanager
def catch_error(error_value):
try:
yield
except Exception as e:
error_message = traceback.format_exc()
print(f"Error {error_value}:\n{error_message}")
def parse_model_configs_from_an_example(path):
model_configs = []
with open(path, "r") as f:
for code in f.readlines():
code = code.strip()
if not code.startswith("ModelConfig"):
continue
pairs = re.findall(r'(\w+)\s*=\s*["\']([^"\']+)["\']', code)
config_dict = {k: v for k, v in pairs}
model_configs.append(ModelConfig(model_id=config_dict["model_id"], origin_file_pattern=config_dict["origin_file_pattern"]))
return model_configs
def list_examples(path, keyword=None):
examples = []
if os.path.isdir(path):
for file_name in os.listdir(path):
examples.extend(list_examples(os.path.join(path, file_name), keyword=keyword))
elif path.endswith(".py"):
with open(path, "r") as f:
code = f.read()
if keyword is None or keyword in code:
examples.extend([path])
return examples
def parse_available_pipelines():
from diffsynth.diffusion.base_pipeline import BasePipeline
import diffsynth.pipelines as _pipelines_pkg
available_pipelines = {}
for _, name, _ in pkgutil.iter_modules(_pipelines_pkg.__path__):
with catch_error(f"Failed: import diffsynth.pipelines.{name}"):
mod = importlib.import_module(f"diffsynth.pipelines.{name}")
classes = {
cls_name: cls for cls_name, cls in inspect.getmembers(mod, inspect.isclass)
if issubclass(cls, BasePipeline) and cls is not BasePipeline and cls.__module__ == mod.__name__
}
available_pipelines.update(classes)
return available_pipelines
def parse_available_examples(path, available_pipelines):
available_examples = {}
for pipeline_name in available_pipelines:
examples = ["None"] + list_examples(path, keyword=f"{pipeline_name}.from_pretrained")
available_examples[pipeline_name] = examples
return available_examples
def draw_selectbox(label, options, option_map, value=None, disabled=False):
default_index = 0 if value is None else tuple(options).index([option for option in option_map if option_map[option]==value][0])
option = st.selectbox(label=label, options=tuple(options), index=default_index, disabled=disabled)
return option_map.get(option)
def parse_params(fn):
params = []
for name, param in inspect.signature(fn).parameters.items():
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else None
default = param.default if param.default is not inspect.Parameter.empty else None
params.append({"name": name, "dtype": annotation, "value": default})
return params
def draw_model_config(model_config=None, key_suffix="", disabled=False):
with st.container(border=True):
if model_config is None:
model_config = ModelConfig()
path = st.text_input(label="path", key="path" + key_suffix, value=model_config.path, disabled=disabled)
col1, col2 = st.columns(2)
with col1:
model_id = st.text_input(label="model_id", key="model_id" + key_suffix, value=model_config.model_id, disabled=disabled)
with col2:
origin_file_pattern = st.text_input(label="origin_file_pattern", key="origin_file_pattern" + key_suffix, value=model_config.origin_file_pattern, disabled=disabled)
model_config = ModelConfig(
path=None if path == "" else path,
model_id=model_id,
origin_file_pattern=origin_file_pattern,
)
return model_config
def draw_multi_model_config(name="", value=None, disabled=False):
model_configs = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
model_config = draw_model_config(key_suffix=f"_{name}_{i}", model_config=None if value is None else value[i], disabled=disabled)
model_configs.append(model_config)
return model_configs
def draw_single_model_config(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
model_config = draw_model_config(value, key_suffix=f"_{name}", disabled=disabled)
return model_config
def draw_multi_images(name="", value=None, disabled=False):
images = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
image = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], key=f"{name}_{i}", disabled=disabled)
if image is not None: images.append(Image.open(image))
return images
def draw_controlnet_input(name="", value=None, disabled=False):
with st.container(border=True):
st.markdown(name)
controlnet_id = st.number_input("controlnet_id", value=0, min_value=0, max_value=20, step=1, key=f"{name}_controlnet_id")
scale = st.number_input("scale", value=1.0, min_value=0.0, max_value=10.0, key=f"{name}_scale")
image = st.file_uploader("image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_image")
if image is not None: image = Image.open(image)
inpaint_image = st.file_uploader("inpaint_image", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_image")
if inpaint_image is not None: inpaint_image = Image.open(inpaint_image)
inpaint_mask = st.file_uploader("inpaint_mask", type=["png", "jpg", "jpeg", "webp"], disabled=disabled, key=f"{name}_inpaint_mask")
if inpaint_mask is not None: inpaint_mask = Image.open(inpaint_mask)
return ControlNetInput(controlnet_id=controlnet_id, scale=scale, image=image, inpaint_image=inpaint_image, inpaint_mask=inpaint_mask)
def draw_controlnet_inputs(name, value=None, disabled=False):
controlnet_inputs = []
with st.container(border=True):
st.markdown(name)
num = st.number_input(f"num_{name}", min_value=0, max_value=20, value=0 if value is None else len(value), disabled=disabled)
for i in range(num):
controlnet_input = draw_controlnet_input(name=f"{name}_{i}", value=None, disabled=disabled)
controlnet_inputs.append(controlnet_input)
return controlnet_inputs
def draw_ui_element(name, dtype, value):
unsupported_dtype = [
Dict[str, torch.Tensor],
torch.Tensor,
]
if dtype in unsupported_dtype:
return
if value is None:
with st.container(border=True):
enable = st.checkbox(f"Enable {name}", value=False)
ui = draw_ui_element_safely(name, dtype, value, disabled=not enable)
if enable:
return ui
else:
return None
else:
return draw_ui_element_safely(name, dtype, value)
def draw_ui_element_safely(name, dtype, value, disabled=False):
if dtype == torch.dtype:
option_map = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == Union[str, torch.device]:
option_map = {"cuda": "cuda", "cpu": "cpu"}
ui = draw_selectbox(name, option_map.keys(), option_map, value=value, disabled=disabled)
elif dtype == bool:
ui = st.checkbox(name, value, disabled=disabled)
elif dtype == ModelConfig:
ui = draw_single_model_config(name, value, disabled=disabled)
elif dtype == list[ModelConfig]:
if name == "model_configs" and "model_configs_from_example" in st.session_state:
model_configs = st.session_state["model_configs_from_example"]
del st.session_state["model_configs_from_example"]
ui = draw_multi_model_config(name, model_configs, disabled=disabled)
else:
ui = draw_multi_model_config(name, disabled=disabled)
elif dtype == str:
if "prompt" in name:
ui = st.text_area(name, value, height=3, disabled=disabled)
else:
ui = st.text_input(name, value, disabled=disabled)
elif dtype == float:
ui = st.number_input(name, value, disabled=disabled)
elif dtype == int:
ui = st.number_input(name, value, step=1, disabled=disabled)
elif dtype == Image.Image:
ui = st.file_uploader(name, type=["png", "jpg", "jpeg", "webp"], disabled=disabled)
if ui is not None: ui = Image.open(ui)
elif dtype == List[Image.Image]:
ui = draw_multi_images(name, value, disabled=disabled)
elif dtype == List[ControlNetInput]:
ui = draw_controlnet_inputs(name, value, disabled=disabled)
elif dtype is None:
if name == "progress_bar_cmd":
ui = value
else:
st.markdown(f"(`{name}` is not not configurable in WebUI). dtype: `{dtype}`.")
ui = value
return ui
def launch_webui():
input_col, output_col = st.columns(2)
with input_col:
if "available_pipelines" not in st.session_state:
st.session_state["available_pipelines"] = parse_available_pipelines()
if "available_examples" not in st.session_state:
st.session_state["available_examples"] = parse_available_examples("./examples", st.session_state["available_pipelines"])
with st.expander("Pipeline", expanded=True):
pipeline_class = draw_selectbox("Pipeline Class", st.session_state["available_pipelines"].keys(), st.session_state["available_pipelines"], value=st.session_state["available_pipelines"]["ZImagePipeline"])
example = st.selectbox("Parse model configs from an example (optional)", st.session_state["available_examples"][pipeline_class.__name__])
if example != "None":
st.session_state["model_configs_from_example"] = parse_model_configs_from_an_example(example)
if st.button("Step 1: Parse Pipeline", type="primary"):
st.session_state["pipeline_class"] = pipeline_class
if "pipeline_class" not in st.session_state:
return
with st.expander("Model", expanded=True):
input_params = {}
params = parse_params(pipeline_class.from_pretrained)
for param in params:
input_params[param["name"]] = draw_ui_element(**param)
if st.button("Step 2: Load Models", type="primary"):
with st.spinner("Loading models", show_time=True):
if "pipe" in st.session_state:
del st.session_state["pipe"]
torch.cuda.empty_cache()
st.session_state["pipe"] = pipeline_class.from_pretrained(**input_params)
if "pipe" not in st.session_state:
return
with st.expander("Input", expanded=True):
pipe = st.session_state["pipe"]
input_params = {}
params = parse_params(pipe.__call__)
for param in params:
if param["name"] in ["self"]:
continue
input_params[param["name"]] = draw_ui_element(**param)
with output_col:
if st.button("Step 3: Generate", type="primary"):
if "progress_bar_cmd" in input_params:
input_params["progress_bar_cmd"] = lambda iterable: StreamlitTqdmWrapper(iterable, st.progress(0))
result = pipe(**input_params)
st.session_state["result"] = result
if "result" in st.session_state:
result = st.session_state["result"]
if isinstance(result, Image.Image):
st.image(result)
else:
print(f"unsupported result format: {result}")
launch_webui()
# streamlit run examples/dev_tools/webui.py --server.fileWatcherType none