improve dml mode performance (20% faster, https://github.com/BlinkDL/ChatRWKV/pull/181)

This commit is contained in:
josc146 2023-10-30 20:24:57 +08:00
parent f86b7f1f08
commit 14b90bb36b

View File

@ -81,8 +81,8 @@ class PIPELINE:
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0): def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = F.softmax(logits.float(), dim=-1) probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k) top_k = int(top_k)
if probs.device == torch.device("cpu"): if probs.device.type in ["cpu", "privateuseone"]:
probs = probs.numpy() probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs) sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1] sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs) cumulative_probs = np.cumsum(sorted_probs)