This commit is contained in:
mi804
2026-04-22 12:47:38 +08:00
parent f5a3201d42
commit b0680ef711
15 changed files with 523 additions and 14 deletions

View File

@@ -22,7 +22,7 @@ from typing import Optional
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from torch.nn.utils import weight_norm, remove_weight_norm
class Snake1d(nn.Module):
@@ -240,3 +240,9 @@ class AceStepVAE(nn.Module):
"""Full round-trip: encode → decode."""
z = self.encode(sample)
return self.decoder(z)
def remove_weight_norm(self):
"""Remove weight normalization from all conv layers (for export/inference)."""
for module in self.modules():
if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
remove_weight_norm(module)