diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index 7664fc5..4c5622c 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -8,6 +8,7 @@ from torch.nn.utils.rnn import pad_sequence from torch.nn import RMSNorm from ..core.attention import attention_forward +from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE from ..core.gradient import gradient_checkpoint_forward @@ -274,7 +275,10 @@ class RopeEmbedder: result = [] for i in range(len(self.axes_dims)): index = ids[:, i] - result.append(self.freqs_cis[i][index]) + if IS_NPU_AVAILABLE: + result.append(self.freqs_cis[i][index]) + else: + result.append(torch.index_select(self.freqs_cis[i], 0, index)) return torch.cat(result, dim=-1)