diff --git a/diffsynth/models/z_image_dit.py b/diffsynth/models/z_image_dit.py index 4c5622c..d20ec51 100644 --- a/diffsynth/models/z_image_dit.py +++ b/diffsynth/models/z_image_dit.py @@ -276,9 +276,9 @@ class RopeEmbedder: for i in range(len(self.axes_dims)): index = ids[:, i] if IS_NPU_AVAILABLE: - result.append(self.freqs_cis[i][index]) - else: result.append(torch.index_select(self.freqs_cis[i], 0, index)) + else: + result.append(self.freqs_cis[i][index]) return torch.cat(result, dim=-1)