From 14b90bb36b01522a53619e236c1fa08c09356944 Mon Sep 17 00:00:00 2001 From: josc146 Date: Mon, 30 Oct 2023 20:24:57 +0800 Subject: [PATCH] improve dml mode performance (20% faster, https://github.com/BlinkDL/ChatRWKV/pull/181) --- backend-python/rwkv_pip/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index 8da5741..36e165d 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -81,8 +81,8 @@ class PIPELINE: def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0): probs = F.softmax(logits.float(), dim=-1) top_k = int(top_k) - if probs.device == torch.device("cpu"): - probs = probs.numpy() + if probs.device.type in ["cpu", "privateuseone"]: + probs = probs.cpu().numpy() sorted_ids = np.argsort(probs) sorted_probs = probs[sorted_ids][::-1] cumulative_probs = np.cumsum(sorted_probs)