[NPU]:Replace 'cuda' in the project with abstract interfaces

This commit is contained in:
feng0w0
2026-01-16 10:28:24 +08:00
parent dce77ec4d1
commit ad91d41601
4 changed files with 1 additions and 38 deletions

View File

@@ -4,4 +4,3 @@ from .gradient import *
from .loader import *
from .vram import *
from .device import *
from .npu_patch import *

View File

@@ -1,5 +0,0 @@
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
from .npu_autocast_patch import npu_autocast_patch
if IS_NPU_AVAILABLE:
npu_autocast_patch()

View File

@@ -1,21 +0,0 @@
import torch
from contextlib import contextmanager
def npu_autocast_patch_wrapper(func):
@contextmanager
def wrapper(*args, **kwargs):
flag = False
if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"):
if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32):
flag = True
with func(*args, **kwargs) as ctx:
if flag:
torch.npu.set_autocast_enabled(True)
yield ctx
return wrapper
def npu_autocast_patch():
torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast)
torch.autocast = npu_autocast_patch_wrapper(torch.autocast)