mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
Merge pull request #1191 from Feng0w0/wan_rope
[model][NPU]: Wan model RoPE uses torch.complex64 on NPU
This commit is contained in:
@@ -1,2 +1,2 @@
|
|||||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
|
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
|
||||||
from .npu_compatible_device import IS_NPU_AVAILABLE
|
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import math
|
|||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .wan_video_camera_controller import SimpleAdapter
|
from .wan_video_camera_controller import SimpleAdapter
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import flash_attn_interface
|
import flash_attn_interface
|
||||||
FLASH_ATTN_3_AVAILABLE = True
|
FLASH_ATTN_3_AVAILABLE = True
|
||||||
@@ -92,6 +93,7 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
|
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
|
||||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
sp_rank = get_sequence_parallel_rank()
|
sp_rank = get_sequence_parallel_rank()
|
||||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||||
|
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
|
||||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
|
|||||||
```
|
```
|
||||||
|
|
||||||
### Training
|
### Training
|
||||||
NPU startup script samples have been added for each type of model; the scripts are stored in `examples/xxx/special/npu_scripts`, for example `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`.
|
NPU startup script samples have been added for each type of model; the scripts are stored in `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
|
||||||
|
|
||||||
In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models.
|
In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models.
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ save_video(video, "video.mp4", fps=15, quality=5)
|
|||||||
```
|
```
|
||||||
|
|
||||||
### 训练
|
### 训练
|
||||||
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_scripts`目录下,例如 `examples/wanvideo/model_training/special/npu_scripts/Wan2.2-T2V-A14B-NPU.sh`。
|
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。
|
||||||
|
|
||||||
在NPU训练脚本中,添加了可以优化性能的NPU特有环境变量,并针对特定模型开启了相关参数。
|
在NPU训练脚本中,添加了可以优化性能的NPU特有环境变量,并针对特定模型开启了相关参数。
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user