upgrade to rwkv 0.8.20
This commit is contained in:
1
backend-python/rwkv_pip/utils.py
vendored
1
backend-python/rwkv_pip/utils.py
vendored
@@ -81,6 +81,7 @@ class PIPELINE:
|
||||
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
|
||||
probs = F.softmax(logits.float(), dim=-1)
|
||||
top_k = int(top_k)
|
||||
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
|
||||
if probs.device.type in ["cpu", "privateuseone"]:
|
||||
probs = probs.cpu().numpy()
|
||||
sorted_ids = np.argsort(probs)
|
||||
|
||||
Reference in New Issue
Block a user