[NPU]: Replace 'cuda' in the project with abstract interfaces

This commit is contained in:
feng0w0
2026-01-16 10:28:24 +08:00
parent dce77ec4d1
commit ad91d41601
4 changed files with 1 addition and 38 deletions

View File

@@ -4,4 +4,3 @@ from .gradient import *
from .loader import *
from .vram import *
from .device import *
from .npu_patch import *

View File

@@ -1,5 +0,0 @@
# Apply the NPU autocast monkey-patch at import time, but only when an
# Ascend NPU is actually available; on other backends importing this
# module is a no-op.
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE
from .npu_autocast_patch import npu_autocast_patch

if IS_NPU_AVAILABLE:
    npu_autocast_patch()

View File

@@ -1,21 +0,0 @@
from contextlib import contextmanager
from functools import wraps

import torch
def npu_autocast_patch_wrapper(func):
    """Wrap a torch autocast context-manager factory for Ascend NPU.

    When the wrapped factory is entered for the ``"npu"`` device together
    with ``torch.float32`` (passed positionally or by keyword), this wrapper
    additionally calls ``torch.npu.set_autocast_enabled(True)`` right after
    entering the context — presumably working around torch_npu leaving
    autocast off for that combination (behavior inherited from the original
    code; TODO confirm against the torch_npu version in use).

    Args:
        func: The original autocast factory, e.g. ``torch.autocast`` or
            ``torch.amp.autocast``.

    Returns:
        A context manager with the same call signature and yielded value
        as ``func``.
    """
    @wraps(func)  # preserve the original factory's name/doc for introspection
    @contextmanager
    def wrapper(*args, **kwargs):
        # The workaround is only needed for the npu + float32 combination;
        # either value may arrive positionally or as a keyword argument.
        targets_npu = "npu" in args or kwargs.get("device_type") == "npu"
        wants_fp32 = torch.float32 in args or kwargs.get("dtype") == torch.float32
        with func(*args, **kwargs) as ctx:
            if targets_npu and wants_fp32:
                # NOTE(review): torch.npu is provided by the torch_npu
                # package; this branch assumes it is imported whenever the
                # "npu" device is requested.
                torch.npu.set_autocast_enabled(True)
            yield ctx
    return wrapper
def npu_autocast_patch():
    """Monkey-patch torch's autocast entry points with the NPU-aware wrapper.

    Rebinds both ``torch.amp.autocast`` and ``torch.autocast`` so that every
    autocast region goes through ``npu_autocast_patch_wrapper``.

    Idempotent: repeated calls apply the patch only once, so the wrapper can
    never stack on itself (the original version re-wrapped on every call).
    """
    # Guard flag stored on the function itself so a second call is a no-op.
    if getattr(npu_autocast_patch, "_applied", False):
        return
    torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast)
    torch.autocast = npu_autocast_patch_wrapper(torch.autocast)
    npu_autocast_patch._applied = True

View File

@@ -9,7 +9,7 @@ import numpy as np
import torch.nn.functional as F
from einops import rearrange, repeat
from .wan_video_dit import flash_attention
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
from ..core.device.npu_compatible_device import get_device_type
from ..core.gradient import gradient_checkpoint_forward
@@ -375,8 +375,6 @@ class FinalLayer_FP32(nn.Module):
T, _, _ = latent_shape
with amp.autocast(get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
x = self.linear(x)
@@ -587,8 +585,6 @@ class LongCatSingleStreamBlock(nn.Module):
# compute modulation params in fp32
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
shift_msa, scale_msa, gate_msa, \
shift_mlp, scale_mlp, gate_mlp = \
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -608,8 +604,6 @@ class LongCatSingleStreamBlock(nn.Module):
x_s = attn_outputs
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -623,8 +617,6 @@ class LongCatSingleStreamBlock(nn.Module):
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
x_s = self.ffn(x_m)
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
x = x.to(x_dtype)
@@ -807,8 +799,6 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
if IS_NPU_AVAILABLE:
torch.npu.set_autocast_enabled(True)
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]