commit b0680ef711 (parent f5a3201d42)
mi804 committed 2026-04-22 12:47:38 +08:00
15 changed files with 523 additions and 14 deletions

View File

@@ -295,6 +295,43 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
# ACE-Step module maps
"diffsynth.models.ace_step_dit.AceStepDiTModel": {
"diffsynth.models.ace_step_dit.AceStepDiTLayer": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_conditioner.AceStepConditionEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_text_encoder.AceStepTextEncoder": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_vae.AceStepVAE": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
"diffsynth.models.ace_step_vae.Snake1d": "diffsynth.core.vram.layers.AutoWrappedModule",
},
"diffsynth.models.ace_step_tokenizer.AceStepTokenizer": {
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
"transformers.models.qwen3.modeling_qwen3.Qwen3RotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
},
}
def QwenImageTextEncoder_Module_Map_Updater():
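The new entries register every ACE-Step submodule class with a VRAM wrapper, so offload decisions can be made per layer type. For a rough picture of how a class-path → wrapper map like this can be consumed, here is a minimal sketch; the `resolve`/`wrap_by_map` helpers and the toy `OffloadedLinear` are hypothetical stand-ins, not diffsynth's actual `AutoWrappedModule` machinery:

```python
import importlib

import torch
import torch.nn as nn


def resolve(path: str):
    # "pkg.mod.Name" -> the class object it names
    module_path, _, name = path.rpartition(".")
    return getattr(importlib.import_module(module_path), name)


class OffloadedLinear(nn.Module):
    # Toy stand-in for AutoWrappedLinear: weights stay on CPU and are
    # copied to the input's device only for the duration of the call.
    def __init__(self, inner: nn.Linear):
        super().__init__()
        self.inner = inner.to("cpu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.inner.weight.to(x.device)
        bias = None if self.inner.bias is None else self.inner.bias.to(x.device)
        return nn.functional.linear(x, weight, bias)


def wrap_by_map(root: nn.Module, module_map: dict) -> None:
    # Collect matches first, then swap, so freshly inserted wrappers
    # (which still hold the original module) are not wrapped again.
    targets = {resolve(src): resolve(dst) for src, dst in module_map.items()}
    hits = [
        (parent, name, child, targets[type(child)])
        for parent in root.modules()
        for name, child in parent.named_children()
        if type(child) in targets
    ]
    for parent, name, child, wrapper in hits:
        setattr(parent, name, wrapper(child))


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
wrap_by_map(model, {"torch.nn.Linear": f"{__name__}.OffloadedLinear"})
print(model)  # both Linear layers replaced by OffloadedLinear
```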

View File

@@ -522,7 +522,7 @@ class AceStepDiTLayer(nn.Module):
         # Extract scale-shift parameters for adaptive layer norm from timestep embeddings
         # 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb
+            self.scale_shift_table.to(temb.device) + temb
         ).chunk(6, dim=1)
         # Step 1: Self-attention with adaptive layer norm (AdaLN)
@@ -889,7 +889,7 @@ class AceStepDiTModel(nn.Module):
             return hidden_states
         # Extract scale-shift parameters for adaptive output normalization
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
         shift = shift.to(hidden_states.device)
         scale = scale.to(hidden_states.device)
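Both hunks in this file fix the same failure mode: `scale_shift_table` is a module parameter, so once layers are offloaded it can sit on a different device than `temb`, and the addition must cast it explicitly. A minimal sketch of the AdaLN scale-shift pattern being patched (toy dimensions, not the real ACE-Step module):

```python
import torch
import torch.nn as nn


class AdaLNOutput(nn.Module):
    # Sketch of the adaptive output norm above: a learned (2, dim) table is
    # added to the timestep embedding, then split into shift and scale.
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.scale_shift_table = nn.Parameter(torch.zeros(2, dim))

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        # (2, dim) broadcasts against (batch, 1, dim); .to(temb.device) keeps
        # the add legal when the parameter was offloaded elsewhere.
        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)
        return self.norm(hidden_states) * (1 + scale) + shift


layer = AdaLNOutput(64)
out = layer(torch.randn(2, 10, 64), torch.randn(2, 64))
print(out.shape)  # torch.Size([2, 10, 64])
```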

View File

@@ -594,7 +594,7 @@ class AudioTokenDetokenizer(nn.Module):
         x = self.embed_tokens(x)
         x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
         special_tokens = self.special_tokens.expand(B, T, -1, -1)
-        x = x + special_tokens
+        x = x + special_tokens.to(x.device)
         x = rearrange(x, "b t p c -> (b t) p c")
         cache_position = torch.arange(0, x.shape[1], device=x.device)
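Same class of fix in the detokenizer: `special_tokens` is a stored tensor that may not follow the activations' device, so it is cast before the add. The surrounding pool-window logic looks roughly like this (illustrative shapes and a zero `special_tokens`, not the trained values):

```python
import torch
from einops import rearrange

B, T, P, C = 2, 6, 5, 64                  # batch, tokens, pool window, channels
embed_tokens = torch.nn.Embedding(1024, C)
special_tokens = torch.zeros(1, 1, P, C)  # learned per-slot offsets in the real model

token_ids = torch.randint(0, 1024, (B, T))
x = embed_tokens(token_ids)               # (B, T, C)
x = x.unsqueeze(2).repeat(1, 1, P, 1)     # (B, T, P, C): each token copied P times
x = x + special_tokens.expand(B, T, -1, -1).to(x.device)  # device-safe broadcast add
x = rearrange(x, "b t p c -> (b t) p c")  # fold batch/time for the inner decoder
cache_position = torch.arange(0, x.shape[1], device=x.device)
print(x.shape, cache_position.shape)      # torch.Size([12, 5, 64]) torch.Size([5])
```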

View File

@@ -22,7 +22,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils import weight_norm, remove_weight_norm
 class Snake1d(nn.Module):
@@ -240,3 +240,9 @@ class AceStepVAE(nn.Module):
"""Full round-trip: encode → decode."""
z = self.encode(sample)
return self.decoder(z)
def remove_weight_norm(self):
"""Remove weight normalization from all conv layers (for export/inference)."""
for module in self.modules():
if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
remove_weight_norm(module)
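For context on why the pipeline change below calls this at load time: `weight_norm` reparametrizes a layer's weight as `g * v / ||v||` and recomputes it on every forward, which is pure overhead once training is done. `remove_weight_norm` folds the current effective weight back into a plain parameter, leaving outputs unchanged. A self-contained check:

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

conv = weight_norm(nn.Conv1d(4, 4, kernel_size=3))  # weight = g * v / ||v||, rebuilt per call
x = torch.randn(1, 4, 16)
y_before = conv(x)

remove_weight_norm(conv)  # bake the effective weight in; g and v are dropped
y_after = conv(x)
print(torch.allclose(y_before, y_after, atol=1e-6))  # True: same outputs, less work
```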

View File

@@ -69,6 +69,7 @@ class AceStepPipeline(BasePipeline):
         pipe.conditioner = model_pool.fetch_model("ace_step_conditioner")
         pipe.dit = model_pool.fetch_model("ace_step_dit")
         pipe.vae = model_pool.fetch_model("ace_step_vae")
+        pipe.vae.remove_weight_norm()
         pipe.tokenizer_model = model_pool.fetch_model("ace_step_tokenizer")
         if text_tokenizer_config is not None:
@@ -372,8 +373,9 @@ class AceStepUnit_ConditionEmbedder(PipelineUnit):
         )
         inputs_posi["encoder_hidden_states"] = encoder_hidden_states
         inputs_posi["encoder_attention_mask"] = encoder_attention_mask
-        inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states)
-        inputs_nega["encoder_attention_mask"] = encoder_attention_mask
+        if inputs_shared["cfg_scale"] != 1.0:
+            inputs_nega["encoder_hidden_states"] = pipe.conditioner.null_condition_emb.expand_as(encoder_hidden_states).to(dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
+            inputs_nega["encoder_attention_mask"] = encoder_attention_mask
         return inputs_shared, inputs_posi, inputs_nega
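The new guard reflects how classifier-free guidance combines the two branches: at `cfg_scale == 1.0` the mix reduces to the positive prediction alone, so building negative conditioning is wasted work. A sketch of that identity (a generic CFG combine, not the pipeline's exact code):

```python
import torch


def cfg_combine(noise_posi: torch.Tensor, noise_nega, cfg_scale: float) -> torch.Tensor:
    # nega + s * (posi - nega) == posi when s == 1.0, so the negative
    # branch (and its conditioning) can be skipped entirely in that case.
    if cfg_scale == 1.0 or noise_nega is None:
        return noise_posi
    return noise_nega + cfg_scale * (noise_posi - noise_nega)


posi, nega = torch.randn(2, 8), torch.randn(2, 8)
print(torch.allclose(nega + 1.0 * (posi - nega), posi, atol=1e-6))  # True
```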
@@ -468,10 +470,15 @@ class AceStepUnit_AudioCodeDecoder(PipelineUnit):
return {"lm_hints": None}
pipe.load_models_to_device(["tokenizer_model"])
indices = torch.tensor(code_ids, device=pipe.device, dtype=torch.long)
indices = indices.unsqueeze(0).unsqueeze(-1) # [1, N, 1]
quantized = pipe.tokenizer_model.tokenizer.quantizer.get_output_from_indices(indices).to(pipe.torch_dtype) # [1, N, 2048]
lm_hints = pipe.tokenizer_model.detokenizer(quantized) # [1, N*5, 64]
quantizer = pipe.tokenizer_model.tokenizer.quantizer
detokenizer = pipe.tokenizer_model.detokenizer
indices = torch.tensor(code_ids, device=quantizer.codebooks.device, dtype=torch.long).unsqueeze(0).unsqueeze(-1)
codes = quantizer.get_codes_from_indices(indices)
quantized = codes.sum(dim=0).to(pipe.torch_dtype).to(pipe.device)
quantized = quantizer.project_out(quantized)
lm_hints = detokenizer(quantized).to(pipe.device)
return {"lm_hints": lm_hints}