upgrade to rwkv 0.8.20

This commit is contained in:
josc146
2023-11-03 23:27:14 +08:00
parent 35e92d2aef
commit 1f81a1e5a8
6 changed files with 188 additions and 359 deletions

View File

@@ -81,6 +81,7 @@ class PIPELINE:
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
if probs.device.type in ["cpu", "privateuseone"]:
probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs)