[NPU]: Support USP feature in NPU

This commit is contained in:
feng0w0
2026-01-21 10:38:27 +08:00
parent d879d66c62
commit b3cc652dea

View File

@@ -33,14 +33,16 @@ def sinusoidal_embedding_1d(dim, position):
def pad_freqs(original_tensor, target_len):
    """Pad a (seq_len, s1, s2) frequency tensor with ones along dim 0 up to target_len.

    NPU workaround: torch.cat is performed on CPU when the input lives on an
    NPU device, then the result is moved back to the original device.

    Args:
        original_tensor: tensor of shape (seq_len, s1, s2).
        target_len: desired length of dim 0; assumed >= seq_len — TODO confirm
            callers never pass a shorter target (pad_size would be negative).

    Returns:
        Tensor of shape (target_len, s1, s2) on the input's original device,
        where rows [seq_len:] are filled with ones.
    """
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    # Remember where the caller's tensor lives so we can restore it at the end.
    original_tensor_device = original_tensor.device
    # BUGFIX: compare the device *type*, not the device object, to a string.
    # torch.device('npu:0') == 'npu' is False, so the original check never
    # triggered the CPU fallback for indexed NPU devices.
    if original_tensor.device.type == "npu":
        original_tensor = original_tensor.cpu()
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)  # CPU if we fell back above, else unchanged
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
    return padded_tensor
def rope_apply(x, freqs, num_heads): def rope_apply(x, freqs, num_heads):