improve dml mode performance (20% faster, https://github.com/BlinkDL/ChatRWKV/pull/181)
commit 14b90bb36b (parent f86b7f1f08)
1 changed file: backend-python/rwkv_pip/utils.py (vendored), 4 lines changed
@@ -81,8 +81,8 @@ class PIPELINE:
     def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
         probs = F.softmax(logits.float(), dim=-1)
         top_k = int(top_k)
-        if probs.device == torch.device("cpu"):
-            probs = probs.numpy()
+        if probs.device.type in ["cpu", "privateuseone"]:
+            probs = probs.cpu().numpy()
             sorted_ids = np.argsort(probs)
             sorted_probs = probs[sorted_ids][::-1]
             cumulative_probs = np.cumsum(sorted_probs)
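For context: DirectML ("dml" mode) tensors report probs.device.type == "privateuseone", and torch.Tensor.numpy() only works on CPU tensors, so the patch widens the device check and adds an explicit .cpu() before the NumPy conversion; the top-p/top-k filtering then runs in NumPy on the host rather than on the DML device, which is where the reported ~20% speedup comes from. Below is a minimal standalone sketch of that sampling path. The free function form and everything after the cumulative_probs line (nucleus cutoff, top-k masking, temperature, the final draw) are an illustrative completion, not a verbatim copy of the vendored utils.py.

import numpy as np
import torch
import torch.nn.functional as F

def sample_logits(logits, temperature=1.0, top_p=0.85, top_k=0):
    # Sketch of the CPU/DirectML branch only; assumes a 1-D logits tensor.
    probs = F.softmax(logits.float(), dim=-1)
    top_k = int(top_k)
    # DirectML tensors have device.type == "privateuseone"; .numpy() requires
    # a CPU tensor, so move the probabilities to host memory first.
    if probs.device.type in ["cpu", "privateuseone"]:
        probs = probs.cpu().numpy()
        sorted_ids = np.argsort(probs)
        sorted_probs = probs[sorted_ids][::-1]
        cumulative_probs = np.cumsum(sorted_probs)
        # Illustrative completion: nucleus (top-p) cutoff, optional top-k mask,
        # temperature scaling, then a categorical draw with NumPy.
        cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
        probs[probs < cutoff] = 0
        if 0 < top_k < len(probs):
            probs[sorted_ids[:-top_k]] = 0
        if temperature != 1.0:
            probs = probs ** (1.0 / temperature)
        probs = probs / np.sum(probs)
        return int(np.random.choice(a=len(probs), p=probs))
    raise NotImplementedError("GPU (CUDA) sampling path not sketched here")

Usage, e.g. sample_logits(torch.randn(65536), top_p=0.85) on CPU; with the torch-directml package installed, logits moved to the DirectML device would take the same branch, since that device type is "privateuseone".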