Merge pull request #1191 from Feng0w0/wan_rope

[model][NPU]:Wan model rope use torch.complex64 in NPU
This commit is contained in:
Zhongjie Duan
2026-01-20 10:05:22 +08:00
committed by GitHub
5 changed files with 6 additions and 4 deletions

View File

@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype)