fix RMSNorm precision

This commit is contained in:
Artiprocher
2026-01-14 16:29:43 +08:00
parent a236a17f17
commit acba342a63

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn import RMSNorm
from .general_modules import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
from ..core.gradient import gradient_checkpoint_forward