mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
low_vram
This commit is contained in:
@@ -522,7 +522,7 @@ class AceStepDiTLayer(nn.Module):
|
||||
# Extract scale-shift parameters for adaptive layer norm from timestep embeddings
|
||||
# 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb
|
||||
self.scale_shift_table.to(temb.device) + temb
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# Step 1: Self-attention with adaptive layer norm (AdaLN)
|
||||
@@ -889,7 +889,7 @@ class AceStepDiTModel(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
# Extract scale-shift parameters for adaptive output normalization
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
|
||||
@@ -594,7 +594,7 @@ class AudioTokenDetokenizer(nn.Module):
|
||||
x = self.embed_tokens(x)
|
||||
x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
|
||||
special_tokens = self.special_tokens.expand(B, T, -1, -1)
|
||||
x = x + special_tokens
|
||||
x = x + special_tokens.to(x.device)
|
||||
x = rearrange(x, "b t p c -> (b t) p c")
|
||||
|
||||
cache_position = torch.arange(0, x.shape[1], device=x.device)
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
@@ -240,3 +240,9 @@ class AceStepVAE(nn.Module):
|
||||
"""Full round-trip: encode → decode."""
|
||||
z = self.encode(sample)
|
||||
return self.decoder(z)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
"""Remove weight normalization from all conv layers (for export/inference)."""
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Conv1d) or isinstance(module, nn.ConvTranspose1d):
|
||||
remove_weight_norm(module)
|
||||
|
||||
Reference in New Issue
Block a user