Merge pull request #1191 from Feng0w0/wan_rope

[model][NPU]:Wan model rope use torch.complex64 in NPU
This commit is contained in:
Zhongjie Duan
2026-01-20 10:05:22 +08:00
committed by GitHub
5 changed files with 6 additions and 4 deletions

View File

@@ -1,2 +1,2 @@
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
from .npu_compatible_device import IS_NPU_AVAILABLE from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE

View File

@@ -5,6 +5,7 @@ import math
from typing import Tuple, Optional from typing import Tuple, Optional
from einops import rearrange from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter from .wan_video_camera_controller import SimpleAdapter
try: try:
import flash_attn_interface import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True FLASH_ATTN_3_AVAILABLE = True
@@ -92,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads) x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
x_out = torch.view_as_complex(x.to(torch.float64).reshape( x_out = torch.view_as_complex(x.to(torch.float64).reshape(
x.shape[0], x.shape[1], x.shape[2], -1, 2)) x.shape[0], x.shape[1], x.shape[2], -1, 2))
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
x_out = torch.view_as_real(x_out * freqs).flatten(2) x_out = torch.view_as_real(x_out * freqs).flatten(2)
return x_out.to(x.dtype) return x_out.to(x.dtype)

View File

@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
sp_rank = get_sequence_parallel_rank() sp_rank = get_sequence_parallel_rank()
freqs = pad_freqs(freqs, s_per_rank * sp_size) freqs = pad_freqs(freqs, s_per_rank * sp_size)
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :] freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2) x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
return x_out.to(x.dtype) return x_out.to(x.dtype)

View File

@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
``` ```
### Training ### Training
NPU startup script samples have been added for each type of model, the scripts are stored in the `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`. NPU startup script samples have been added for each type of model, the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models. In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models.

View File

@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
``` ```
### 训练 ### 训练
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_scripts`目录下,例如 `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh` 当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`
在NPU训练脚本中添加了可以优化性能的NPU特有环境变量并针对特定模型开启了相关参数。 在NPU训练脚本中添加了可以优化性能的NPU特有环境变量并针对特定模型开启了相关参数。