Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-22 16:50:47 +00:00
[NPU]: Replace hardcoded 'cuda' device strings in the project with an abstract device interface
@@ -9,6 +9,7 @@ import numpy as np
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from .wan_video_dit import flash_attention
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
 from ..core.gradient import gradient_checkpoint_forward


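Note: the added import pulls IS_NPU_AVAILABLE and get_device_type from ..core.device.npu_compatible_device, which is not shown in this diff. A minimal sketch of what such a helper could look like, assuming torch_npu supplies the Ascend backend (only the two imported names come from the diff; the body below is illustrative):

# Hypothetical sketch of npu_compatible_device.py; the real module is not part of this diff.
import torch

try:
    import torch_npu  # noqa: F401 -- assumed to register the "npu" device type with PyTorch
    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    IS_NPU_AVAILABLE = False

def get_device_type() -> str:
    # Device-type string for torch.amp.autocast: "npu" on Ascend, else "cuda", else "cpu".
    if IS_NPU_AVAILABLE:
        return "npu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"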
@@ -373,7 +374,9 @@ class FinalLayer_FP32(nn.Module):
         B, N, C = x.shape
         T, _, _ = latent_shape

-        with amp.autocast('cuda', dtype=torch.float32):
+        with amp.autocast(get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
             x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
             x = self.linear(x)
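This hunk shows the pattern that recurs in every hunk below: the hardcoded 'cuda' device type is replaced by get_device_type(), and on Ascend the autocast flag is switched on explicitly before the float32 region. A standalone illustration of that pattern, assuming the helper sketched above and an absolute import path inferred from the relative one in the diff:

# Illustrative only; the function name, import path, and math are assumptions, not project code.
import torch
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type

def modulate_in_fp32(x: torch.Tensor) -> torch.Tensor:
    # Force float32 inside the region on whichever backend get_device_type() reports.
    with torch.amp.autocast(get_device_type(), dtype=torch.float32):
        if IS_NPU_AVAILABLE:
            torch.npu.set_autocast_enabled(True)  # same toggle as used in this commit
        return x.float() * 2.0  # placeholder for the real modulation math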
@@ -583,7 +586,9 @@ class LongCatSingleStreamBlock(nn.Module):
         T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.

         # compute modulation params in fp32
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             shift_msa, scale_msa, gate_msa, \
                 shift_mlp, scale_mlp, gate_mlp = \
                 self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
@@ -602,7 +607,9 @@ class LongCatSingleStreamBlock(nn.Module):
         else:
             x_s = attn_outputs

-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
         x = x.to(x_dtype)

@@ -615,7 +622,9 @@ class LongCatSingleStreamBlock(nn.Module):
         # ffn with modulation
         x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
         x_s = self.ffn(x_m)
-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
         x = x.to(x_dtype)

@@ -797,7 +806,9 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):

         hidden_states = self.x_embedder(hidden_states) # [B, N, C]

-        with amp.autocast(device_type='cuda', dtype=torch.float32):
+        with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
+            if IS_NPU_AVAILABLE:
+                torch.npu.set_autocast_enabled(True)
             t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]

         encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
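Since the same three added lines appear at every modified call site, one possible follow-up (not part of this commit) would be to factor them into a small context manager, for example:

# Hypothetical refactor sketch; names and import path are assumptions.
from contextlib import contextmanager
import torch
from diffsynth.core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type

@contextmanager
def fp32_autocast():
    # Device-agnostic float32 autocast region, equivalent to the repeated pattern above.
    with torch.amp.autocast(get_device_type(), dtype=torch.float32):
        if IS_NPU_AVAILABLE:
            torch.npu.set_autocast_enabled(True)
        yield

Call sites would then read "with fp32_autocast():" instead of repeating the three lines at each location.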