[model][NPU]:Z-image model support NPU

This commit is contained in:
feng0w0
2026-01-07 11:31:22 +08:00
parent 8f1d10fb43
commit 3ee5f53a36

View File

@@ -8,6 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
from torch.nn import RMSNorm
from ..core.attention import attention_forward
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
from ..core.gradient import gradient_checkpoint_forward
@@ -274,7 +275,10 @@ class RopeEmbedder:
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
if IS_NPU_AVAILABLE:
result.append(self.freqs_cis[i][index])
else:
result.append(torch.index_select(self.freqs_cis[i], 0, index))
return torch.cat(result, dim=-1)