From 3b662da31e49e2cf9196d8608f7ba0c6c71875ec Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Fri, 9 Jan 2026 18:11:40 +0800 Subject: [PATCH] [model][NPU]:Wan model rope use torch.complex64 in NPU --- diffsynth/utils/xfuser/xdit_context_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index b7fa72d..d365cfe 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -5,7 +5,7 @@ from xfuser.core.distributed import (get_sequence_parallel_rank, get_sequence_parallel_world_size, get_sp_group) from xfuser.core.long_ctx_attention import xFuserLongContextAttention -from ...core.device import parse_nccl_backend, parse_device_type +from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE def initialize_usp(device_type): @@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads): sp_rank = get_sequence_parallel_rank() freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] - + freqs_rank = freqs_rank.to(torch.complex64) if IS_NPU_AVAILABLE else freqs_rank x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) return x_out.to(x.dtype)