mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
[NPU]: Replace hard-coded 'cuda' usages in the project with abstract device interfaces
This commit is contained in:
5
diffsynth/core/npu_patch/__init__.py
Normal file
5
diffsynth/core/npu_patch/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Package init for the NPU compatibility patch layer.
# Importing this package applies the autocast monkey-patch as a side
# effect, but only when an NPU backend (presumably Ascend / `torch.npu`
# — confirm) is actually available.
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
from .npu_autocast_patch import npu_autocast_patch

# Patch at import time so every later `torch.autocast(...)` call made
# anywhere in the project goes through the NPU-aware wrapper.
if IS_NPU_AVAILABLE:
    npu_autocast_patch()
|
||||
21
diffsynth/core/npu_patch/npu_autocast_patch.py
Normal file
21
diffsynth/core/npu_patch/npu_autocast_patch.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
def npu_autocast_patch_wrapper(func):
    """Wrap an autocast factory so NPU float32 autocast is forced on.

    The returned context manager delegates entirely to ``func`` and
    yields the same context object; the only addition is that when the
    caller requested device_type ``"npu"`` together with dtype
    ``torch.float32`` (positionally or by keyword), it also calls
    ``torch.npu.set_autocast_enabled(True)`` after entering the context.
    """

    @contextmanager
    def wrapper(*args, **kwargs):
        # Detect the (npu, float32) combination from either positional
        # or keyword arguments.
        targets_npu = "npu" in args or kwargs.get("device_type") == "npu"
        wants_fp32 = torch.float32 in args or kwargs.get("dtype") == torch.float32
        force_enable = targets_npu and wants_fp32
        with func(*args, **kwargs) as inner:
            if force_enable:
                # NOTE(review): autocast is re-enabled here but never
                # explicitly restored on exit — presumably the wrapped
                # context's own teardown handles that; confirm against
                # torch_npu behavior.
                torch.npu.set_autocast_enabled(True)
            yield inner

    return wrapper
|
||||
|
||||
|
||||
def npu_autocast_patch():
    """Install the NPU-aware autocast wrapper onto torch's entry points.

    Replaces both ``torch.amp.autocast`` and ``torch.autocast`` in
    place so any subsequent autocast usage is routed through
    ``npu_autocast_patch_wrapper``. Calling this more than once stacks
    the wrapper.
    """
    for owner, attr in ((torch.amp, "autocast"), (torch, "autocast")):
        setattr(owner, attr, npu_autocast_patch_wrapper(getattr(owner, attr)))
|
||||
Reference in New Issue
Block a user