Merge branch 'main' into cuda_replace

This commit is contained in:
Zhongjie Duan
2026-01-20 10:12:31 +08:00
committed by GitHub
57 changed files with 996 additions and 116 deletions

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn import RMSNorm
from .general_modules import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.gradient import gradient_checkpoint_forward