From ad91d416018da6cefa4b15ba7e826a45aa93a472 Mon Sep 17 00:00:00 2001 From: feng0w0 Date: Fri, 16 Jan 2026 10:28:24 +0800 Subject: [PATCH] [NPU]:Replace 'cuda' in the project with abstract interfaces --- diffsynth/core/__init__.py | 1 - diffsynth/core/npu_patch/__init__.py | 5 ----- .../core/npu_patch/npu_autocast_patch.py | 21 ------------------- diffsynth/models/longcat_video_dit.py | 12 +---------- 4 files changed, 1 insertion(+), 38 deletions(-) delete mode 100644 diffsynth/core/npu_patch/__init__.py delete mode 100644 diffsynth/core/npu_patch/npu_autocast_patch.py diff --git a/diffsynth/core/__init__.py b/diffsynth/core/__init__.py index 4d5f440..6c0a6c8 100644 --- a/diffsynth/core/__init__.py +++ b/diffsynth/core/__init__.py @@ -4,4 +4,3 @@ from .gradient import * from .loader import * from .vram import * from .device import * -from .npu_patch import * diff --git a/diffsynth/core/npu_patch/__init__.py b/diffsynth/core/npu_patch/__init__.py deleted file mode 100644 index eb1df93..0000000 --- a/diffsynth/core/npu_patch/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE -from .npu_autocast_patch import npu_autocast_patch - -if IS_NPU_AVAILABLE: - npu_autocast_patch() diff --git a/diffsynth/core/npu_patch/npu_autocast_patch.py b/diffsynth/core/npu_patch/npu_autocast_patch.py deleted file mode 100644 index 08b1caf..0000000 --- a/diffsynth/core/npu_patch/npu_autocast_patch.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -from contextlib import contextmanager - - -def npu_autocast_patch_wrapper(func): - @contextmanager - def wrapper(*args, **kwargs): - flag = False - if "npu" in args or ("device_type" in kwargs and kwargs["device_type"] == "npu"): - if torch.float32 in args or ("dtype" in kwargs and kwargs["dtype"] == torch.float32): - flag = True - with func(*args, **kwargs) as ctx: - if flag: - torch.npu.set_autocast_enabled(True) - yield ctx - return wrapper - - -def npu_autocast_patch(): - torch.amp.autocast = npu_autocast_patch_wrapper(torch.amp.autocast) - torch.autocast = npu_autocast_patch_wrapper(torch.autocast) diff --git a/diffsynth/models/longcat_video_dit.py b/diffsynth/models/longcat_video_dit.py index ebcc9d0..dbe1c21 100644 --- a/diffsynth/models/longcat_video_dit.py +++ b/diffsynth/models/longcat_video_dit.py @@ -9,7 +9,7 @@ import numpy as np import torch.nn.functional as F from einops import rearrange, repeat from .wan_video_dit import flash_attention -from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type +from ..core.device.npu_compatible_device import get_device_type from ..core.gradient import gradient_checkpoint_forward @@ -375,8 +375,6 @@ class FinalLayer_FP32(nn.Module): T, _, _ = latent_shape with amp.autocast(get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C] x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C) x = self.linear(x) @@ -587,8 +585,6 @@ class LongCatSingleStreamBlock(nn.Module): # compute modulation params in fp32 with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) shift_msa, scale_msa, gate_msa, \ shift_mlp, scale_mlp, gate_mlp = \ self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C] @@ -608,8 +604,6 @@ class LongCatSingleStreamBlock(nn.Module): x_s = attn_outputs with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -623,8 +617,6 @@ class LongCatSingleStreamBlock(nn.Module): x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C) x_s = self.ffn(x_m) with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C] x = x.to(x_dtype) @@ -807,8 +799,6 @@ class LongCatVideoTransformer3DModel(torch.nn.Module): hidden_states = self.x_embedder(hidden_states) # [B, N, C] with amp.autocast(device_type=get_device_type(), dtype=torch.float32): - if IS_NPU_AVAILABLE: - torch.npu.set_autocast_enabled(True) t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t] encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]