21 lines
986 B
Python
21 lines
986 B
Python
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution,
|
||
|
ShortConvolution)
|
||
|
from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
|
||
|
from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate,
|
||
|
FusedLayerNormSwishGateLinear,
|
||
|
FusedRMSNormSwishGate,
|
||
|
FusedRMSNormSwishGateLinear)
|
||
|
from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm,
|
||
|
RMSNormLinear)
|
||
|
from fla.modules.rotary import RotaryEmbedding
|
||
|
|
||
|
__all__ = [
|
||
|
'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
|
||
|
'FusedCrossEntropyLoss',
|
||
|
'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
|
||
|
'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
|
||
|
'RotaryEmbedding'
|
||
|
]
|