This commit is contained in:
josc146
2023-12-14 18:37:07 +08:00
parent 01c95f5bc4
commit 0ddd2e9fea
16 changed files with 155 additions and 34 deletions

View File

@@ -84,6 +84,8 @@ class PIPELINE:
return e / e.sum(axis=axis, keepdims=True)
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
if type(logits) == list:
logits = np.array(logits)
np_logits = type(logits) == np.ndarray
if np_logits:
probs = self.np_softmax(logits, axis=-1)