mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from .wan_video_dit import flash_attention
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
@@ -375,8 +375,6 @@ class FinalLayer_FP32(nn.Module):
|
||||
T, _, _ = latent_shape
|
||||
|
||||
with amp.autocast(get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
||||
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
||||
x = self.linear(x)
|
||||
@@ -587,8 +585,6 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
|
||||
# compute modulation params in fp32
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
shift_msa, scale_msa, gate_msa, \
|
||||
shift_mlp, scale_mlp, gate_mlp = \
|
||||
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
||||
@@ -608,8 +604,6 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
x_s = attn_outputs
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
@@ -623,8 +617,6 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
||||
x_s = self.ffn(x_m)
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
@@ -807,8 +799,6 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
|
||||
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
if IS_NPU_AVAILABLE:
|
||||
torch.npu.set_autocast_enabled(True)
|
||||
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
||||
|
||||
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
||||
|
||||
Reference in New Issue
Block a user