diff --git a/diffsynth/core/npu_patch/npu_fused_operator.py b/diffsynth/core/npu_patch/npu_fused_operator.py index 5b28eea..7166041 100644 --- a/diffsynth/core/npu_patch/npu_fused_operator.py +++ b/diffsynth/core/npu_patch/npu_fused_operator.py @@ -27,4 +27,4 @@ def rotary_emb_Zimage_npu(self, x_in: torch.Tensor, freqs_cis: torch.Tensor): cos, sin = torch.chunk(torch.view_as_real(freqs_cis), 2, dim=-1) cos = cos.expand(-1, -1, -1, -1, 2).flatten(-2) sin = sin.expand(-1, -1, -1, -1, 2).flatten(-2) - return torch_npu.npu_rotary_mul(x_in, cos, sin).to(x_in) \ No newline at end of file + return torch_npu.npu_rotary_mul(x_in, cos, sin, rotary_mode="interleave").to(x_in) \ No newline at end of file