diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py
index 94b92c7..6f712fd 100644
--- a/diffsynth/utils/xfuser/xdit_context_parallel.py
+++ b/diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -33,14 +33,18 @@ def sinusoidal_embedding_1d(dim, position):
 def pad_freqs(original_tensor, target_len):
     seq_len, s1, s2 = original_tensor.shape
     pad_size = target_len - seq_len
+    original_tensor_device = original_tensor.device
+    # Compare the device *type*: a torch.device is not equal to the bare
+    # string "npu", and real devices carry an index (e.g. "npu:0").
+    if original_tensor_device.type == "npu":
+        original_tensor = original_tensor.cpu()
     padding_tensor = torch.ones(
         pad_size,
         s1,
         s2,
         dtype=original_tensor.dtype,
-        device='cpu')
-    original_tensor_device = original_tensor.device
-    padded_tensor = torch.cat([original_tensor.cpu(), padding_tensor], dim=0).to(device=original_tensor_device)
+        device=original_tensor.device)
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
     return padded_tensor
 
 def rope_apply(x, freqs, num_heads):