[bugfix][NPU]: Fix device-type check so the NPU device is correctly detected

This commit is contained in:
feng0w0
2026-01-23 10:45:03 +08:00
parent ffb7a138f7
commit 35e0776022
2 changed files with 2 additions and 2 deletions

View File

@@ -93,7 +93,7 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)

View File

@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)