This commit is contained in:
mi804
2026-04-22 12:47:38 +08:00
parent f5a3201d42
commit b0680ef711
15 changed files with 523 additions and 14 deletions

View File

@@ -69,6 +69,7 @@ class AceStepPipeline(BasePipeline):
pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
pipe.dit = model_pool.fetch_model("ace_step_dit")
pipe.vae = model_pool.fetch_model("ace_step_vae")
pipe.vae.remove_weight_norm()
pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
if text_tokenizer_config is not None:
@@ -372,8 +373,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
)
inputs_posi["encoder_hidden_states"] = encoder_hidden_states
inputs_posi["encoder_attention_mask"] = encoder_attention_mask
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
if inputs_shared["cfg_scale"] != 1.0:
inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
inputs_nega["encoder_attention_mask"] = encoder_attention_mask
return inputs_shared, inputs_posi, inputs_nega
@@ -468,10 +470,15 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
return {"lm_hints": None}
pipe.load_models_to_device(["tokenizer_model"])
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
quantizer = pipe.tokenizer_model.tokenizer.quantizer
detokenizer = pipe.tokenizer_model.detokenizer
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
lm_hints = detokenizer(quantized).to(pipe.device)
return {"lm_hints": lm_hints}