mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
low_vram
This commit is contained in:
@@ -22,7 +22,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
@@ -240,3 +240,9 @@ class AceStepVAE(nn.Module):
|
||||
"""Full round-trip: encode → decode."""
|
||||
z = self.encode(sample)
|
||||
return self.decoder(z)
|
||||
|
||||
def remove_weight_norm(self):
    """Remove weight normalization from all conv layers (for export/inference).

    Walks every submodule and strips the weight-norm reparametrization from
    each ``nn.Conv1d`` / ``nn.ConvTranspose1d``, folding ``weight_g`` /
    ``weight_v`` back into a plain ``weight`` tensor. Safe to call once
    before export; layers that were never weight-normed are skipped.
    """
    for module in self.modules():
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            # Qualify the call explicitly: this method shares the name
            # ``remove_weight_norm`` with the imported helper, and a bare
            # call can resolve to the method itself (infinite recursion)
            # depending on what is in scope. Going through torch.nn.utils
            # is unambiguous.
            try:
                torch.nn.utils.remove_weight_norm(module)
            except ValueError:
                # This conv layer has no weight_norm hook (never normed,
                # or already stripped) — nothing to remove.
                pass
|
||||
|
||||
Reference in New Issue
Block a user