mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
[NPU]:Replace 'cuda' in the project with abstract interfaces
This commit is contained in:
@@ -8,7 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from torch.nn import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class TimestepEmbedder(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
@@ -105,7 +105,7 @@ class Attention(torch.nn.Module):
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
|
||||
Reference in New Issue
Block a user