mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -5,7 +5,6 @@ import math
|
|||||||
from typing import Tuple, Optional
|
from typing import Tuple, Optional
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from .wan_video_camera_controller import SimpleAdapter
|
from .wan_video_camera_controller import SimpleAdapter
|
||||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import flash_attn_interface
|
import flash_attn_interface
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from xfuser.core.distributed import (get_sequence_parallel_rank,
|
|||||||
get_sequence_parallel_world_size,
|
get_sequence_parallel_world_size,
|
||||||
get_sp_group)
|
get_sp_group)
|
||||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||||
from ...core.device import parse_nccl_backend, parse_device_type, IS_NPU_AVAILABLE
|
from ...core.device import parse_nccl_backend, parse_device_type
|
||||||
|
|
||||||
|
|
||||||
def initialize_usp(device_type):
|
def initialize_usp(device_type):
|
||||||
@@ -50,6 +50,7 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
sp_rank = get_sequence_parallel_rank()
|
sp_rank = get_sequence_parallel_rank()
|
||||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||||
|
|
||||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user