rwkv.cpp(ggml) support

This commit is contained in:
josc146
2023-12-12 20:29:55 +08:00
parent 6e29f97881
commit b14fbc29b7
26 changed files with 1234 additions and 102 deletions

View File

@@ -78,12 +78,22 @@ class PIPELINE:
def decode(self, x):
return self.tokenizer.decode(x)
def np_softmax(self, x: np.ndarray, axis: int):
x -= x.max(axis=axis, keepdims=True)
e: np.ndarray = np.exp(x)
return e / e.sum(axis=axis, keepdims=True)
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
probs = F.softmax(logits.float(), dim=-1)
np_logits = type(logits) == np.ndarray
if np_logits:
probs = self.np_softmax(logits, axis=-1)
else:
probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
if probs.device.type in ["cpu", "privateuseone"]:
probs = probs.cpu().numpy()
if np_logits or probs.device.type in ["cpu", "privateuseone"]:
if not np_logits:
probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs)