[NPU]:Replace 'cuda' in the project with abstract interfaces

This commit is contained in:
feng0w0
2026-01-15 20:33:01 +08:00
parent f4d06ce3fc
commit 209a350c0f
18 changed files with 86 additions and 38 deletions

View File

@@ -0,0 +1,5 @@
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
from .npu_autocast_patch import npu_autocast_patch
if IS_NPU_AVAILABLE:
npu_autocast_patch()

View File

@@ -0,0 +1,21 @@
import torch
from contextlib import contextmanager
def npu_autocast_patch_wrapper(func):
@contextmanager
def wrapper(*args, **kwargs):
flag = False
if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"):
if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32):
flag = True
with func(*args, **kwargs) as ctx:
if flag:
torch.npu.set_autocast_enabled(True)
yield ctx
return wrapper
def npu_autocast_patch():
torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast)
torch.autocast = npu_autocast_patch_wrapper(torch.autocast)