mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'modelscope:main' into wan_rope
This commit is contained in:
@@ -8,6 +8,7 @@ from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from torch.nn import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
@@ -315,7 +316,10 @@ class RopeEmbedder:
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
if IS_NPU_AVAILABLE:
|
||||
result.append(torch.index_select(self.freqs_cis[i], 0, index))
|
||||
else:
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user