Merge pull request #1191 from Feng0w0/wan_rope

[model][NPU]:Wan model rope use torch.complex64 in NPU
This commit is contained in:
Zhongjie Duan
2026-01-20 10:05:22 +08:00
committed by GitHub
5 changed files with 6 additions and 4 deletions

View File

@@ -5,6 +5,7 @@ import math
from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
@@ -92,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype)