mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
[model][NPU]:Wan model rope use torch.complex64 in NPU
This commit is contained in:
@@ -5,6 +5,8 @@ import math
|
|||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .wan_video_camera_controller import SimpleAdapter
|
from .wan_video_camera_controller import SimpleAdapter
|
||||||
|
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import flash_attn_interface
|
import flash_attn_interface
|
||||||
FLASH_ATTN_3_AVAILABLE = True
|
FLASH_ATTN_3_AVAILABLE = True
|
||||||
@@ -92,6 +94,7 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
|
freqs = freqs.to(torch.complex64) if IS_NPU_AVAILABLE else freqs
|
||||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user