mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
style
This commit is contained in:
@@ -3,12 +3,10 @@ ACE-Step Pipeline for DiffSynth-Studio.
|
||||
|
||||
Text-to-Music generation pipeline using ACE-Step 1.5 model.
|
||||
"""
|
||||
import re
|
||||
import torch
|
||||
import re, torch
|
||||
from typing import Optional, Dict, Any, List, Tuple
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
import math
|
||||
import random, math
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
@@ -39,7 +37,7 @@ class AceStepPipeline(BasePipeline):
|
||||
self.conditioner: AceStepConditionEncoder = None
|
||||
self.dit: AceStepDiTModel = None
|
||||
self.vae: AceStepVAE = None
|
||||
self.tokenizer_model: AceStepTokenizer = None # AceStepTokenizer (tokenizer + detokenizer)
|
||||
self.tokenizer_model: AceStepTokenizer = None
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
@@ -65,7 +63,6 @@ class AceStepPipeline(BasePipeline):
|
||||
silence_latent_config: ModelConfig = ModelConfig(model_id="ACE-Step/Ace-Step1.5", origin_file_pattern="acestep-v15-turbo/silence_latent.pt"),
|
||||
vram_limit: float = None,
|
||||
):
|
||||
"""Load pipeline from pretrained checkpoints."""
|
||||
pipe = AceStepPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
@@ -102,7 +99,7 @@ class AceStepPipeline(BasePipeline):
|
||||
reference_audios: List[torch.Tensor] = None,
|
||||
# Source audio
|
||||
src_audio: torch.Tensor = None,
|
||||
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
||||
denoising_strength: float = 1.0, # denoising_strength = 1 - cover_noise_strength
|
||||
audio_cover_strength: float = 1.0,
|
||||
# Audio codes
|
||||
audio_code_string: Optional[str] = None,
|
||||
@@ -115,7 +112,7 @@ class AceStepPipeline(BasePipeline):
|
||||
bpm: Optional[int] = 100,
|
||||
keyscale: Optional[str] = "B minor",
|
||||
timesignature: Optional[str] = "4",
|
||||
vocal_language: Optional[str] = 'unknown',
|
||||
vocal_language: Optional[str] = "unknown",
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
@@ -126,10 +123,10 @@ class AceStepPipeline(BasePipeline):
|
||||
# Progress
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# 1. Scheduler
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps=num_inference_steps, denoising_strength=denoising_strength, shift=shift)
|
||||
|
||||
# 2. 三字典输入
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt, "positive": True}
|
||||
inputs_nega = {"positive": False}
|
||||
inputs_shared = {
|
||||
@@ -147,13 +144,12 @@ class AceStepPipeline(BasePipeline):
|
||||
"shift": shift,
|
||||
}
|
||||
|
||||
# 3. Unit 链执行
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(
|
||||
unit, self, inputs_shared, inputs_posi, inputs_nega
|
||||
)
|
||||
|
||||
# 4. Denoise loop
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
@@ -164,17 +160,17 @@ class AceStepPipeline(BasePipeline):
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id,
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
|
||||
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
inputs_shared["latents"] = self.step(
|
||||
self.scheduler, inpaint_mask=inputs_shared.get("denoise_mask", None), input_latents=inputs_shared.get("src_latents", None),
|
||||
progress_id=progress_id, noise_pred=noise_pred, **inputs_shared,
|
||||
)
|
||||
|
||||
# 5. VAE 解码
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
# DiT output is [B, T, 64] (channels-last), VAE expects [B, 64, T] (channels-first)
|
||||
latents = inputs_shared["latents"].transpose(1, 2)
|
||||
vae_output = self.vae.decode(latents)
|
||||
# VAE returns OobleckDecoderOutput with .sample attribute
|
||||
audio_output = vae_output.sample if hasattr(vae_output, 'sample') else vae_output
|
||||
audio_output = self.normalize_audio(audio_output, target_db=-1.0)
|
||||
audio_output = self.normalize_audio(vae_output, target_db=-1.0)
|
||||
audio = self.output_audio_format_check(audio_output)
|
||||
self.load_models_to_device([])
|
||||
return audio
|
||||
@@ -188,7 +184,9 @@ class AceStepPipeline(BasePipeline):
|
||||
return audio * gain
|
||||
|
||||
def switch_noncover_condition(self, inputs_shared, inputs_posi, inputs_nega, progress_id):
|
||||
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0 or inputs_shared.get("shared_noncover", None) is None:
|
||||
if inputs_shared["task_type"] != "cover" or inputs_shared["audio_cover_strength"] >= 1.0:
|
||||
return
|
||||
if inputs_shared.get("shared_noncover", None) is None:
|
||||
return
|
||||
cover_steps = int(len(self.scheduler.timesteps) * inputs_shared["audio_cover_strength"])
|
||||
if progress_id >= cover_steps:
|
||||
@@ -312,7 +310,10 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||
def process(self, pipe, reference_audios):
|
||||
if reference_audios is not None:
|
||||
pipe.load_models_to_device(['vae'])
|
||||
reference_audios = [self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device) for reference_audio in reference_audios]
|
||||
reference_audios = [
|
||||
self.process_reference_audio(reference_audio).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
for reference_audio in reference_audios
|
||||
]
|
||||
reference_latents, refer_audio_order_mask = self.infer_refer_latent(pipe, [reference_audios])
|
||||
else:
|
||||
reference_audios = [[torch.zeros(2, 30 * pipe.vae.sampling_rate).to(dtype=pipe.torch_dtype, device=pipe.device)]]
|
||||
@@ -357,7 +358,6 @@ class AceStepUnit_ReferenceAudioEmbedder(PipelineUnit):
|
||||
|
||||
|
||||
class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
@@ -378,7 +378,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
|
||||
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
|
||||
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(
|
||||
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device,
|
||||
)
|
||||
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
|
||||
if inputs_shared["task_type"] == "cover" and inputs_shared["audio_cover_strength"] < 1.0:
|
||||
hidden_states_noncover = AceStepUnit_PromptEmbedder().process(
|
||||
@@ -396,7 +398,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
|
||||
inputs_shared["posi_noncover"] = {"encoder_hidden_states": encoder_hidden_states_noncover, "encoder_attention_mask": encoder_attention_mask_noncover}
|
||||
if inputs_shared["cfg_scale"] != 1.0:
|
||||
inputs_shared["nega_noncover"] = {
|
||||
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device),
|
||||
"encoder_hidden_states": pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states_noncover).to(
|
||||
dtype=encoder_hidden_states_noncover.dtype, device=encoder_hidden_states_noncover.device,
|
||||
),
|
||||
"encoder_attention_mask": encoder_attention_mask_noncover,
|
||||
}
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
@@ -483,7 +487,7 @@ class AceStepUnit_ContextLatentBuilder(PipelineUnit):
|
||||
# use audio_cede_string to get src_latents.
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
code_ids = self._parse_audio_code_string(audio_code_string)
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer
|
||||
quantizer = pipe.tokenizer_model.tokenizer.quantizer.to(device=pipe.device)
|
||||
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
|
||||
codes = quantizer.get_codes_from_indices(indices)
|
||||
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
|
||||
|
||||
Reference in New Issue
Block a user