add usage (prompt/completion token counting)
@@ -1,7 +1,7 @@
 import os
 import pathlib
 import copy
-from typing import Dict, List
+from typing import Dict, List, Tuple
 from utils.log import quick_log
 from fastapi import HTTPException
 from pydantic import BaseModel, Field
@@ -71,7 +71,7 @@ The following is a coherent verbose detailed conversation between a girl named {
             else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
             + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
         )
-        logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
+        logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
         try:
             state_cache.add_state(
                 state_cache.AddStateBody(
@@ -92,6 +92,7 @@ The following is a coherent verbose detailed conversation between a girl named {
 
     def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
         tokens = [int(x) for x in _tokens]
+        token_len = len(tokens)
         self.model_tokens += tokens
 
         while len(tokens) > 0:
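With this change run_rnn returns (out, token_len) instead of bare logits (the return statement itself changes in the next hunk), so every call site must unpack two values. A minimal sketch of the new contract, assuming model is the loaded wrapper instance:

    # run_rnn now reports how many tokens it consumed, alongside the logits.
    logits, token_len = model.run_rnn(model.fix_tokens(model.pipeline.encode("hi")))
    # Call sites that do not need the count simply discard it:
    logits, _ = model.run_rnn([0])  # 0 is a placeholder token id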
@@ -104,23 +105,24 @@ The following is a coherent verbose detailed conversation between a girl named {
 
         if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
             out[self.model_tokens[-1]] = -999999999
-        return out
+        return out, token_len
 
-    def get_embedding(self, input: str, fast_mode: bool) -> List[float]:
+    def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
         if fast_mode:
-            embedding = self.fast_embedding(
+            embedding, token_len = self.fast_embedding(
                 self.fix_tokens(self.pipeline.encode(input)), None
             )
         else:
             self.model_state = None
             self.model_tokens = []
-            self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
+            _, token_len = self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
             embedding = self.model_state[-5].tolist()
             embedding = (embedding / np.linalg.norm(embedding)).tolist()
-        return embedding
+        return embedding, token_len
 
     def fast_embedding(self, tokens: List[str], state):
         tokens = [int(x) for x in tokens]
+        token_len = len(tokens)
         self = self.model
 
         with torch.no_grad():
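get_embedding now returns Tuple[List[float], int] on both paths: the fast path takes the count from fast_embedding, the slow path from run_rnn. A hedged usage sketch, assuming model is the wrapper instance and the input string is arbitrary:

    embedding, token_len = model.get_embedding("some input text", fast_mode=True)
    # token_len can feed the usage block of an embeddings response:
    usage = {"prompt_tokens": token_len, "total_tokens": token_len}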
@@ -253,7 +255,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                 ory,
             )
 
-        return state[0].tolist()
+        return state[0].tolist(), token_len
 
     def generate(self, prompt: str, stop: str = None):
         quick_log(None, None, "Generation Prompt:\n" + prompt)
@@ -274,8 +276,11 @@ The following is a coherent verbose detailed conversation between a girl named {
             self.model_tokens = copy.deepcopy(cache["tokens"])
             logits = copy.deepcopy(cache["logits"])
 
+        prompt_token_len = 0
         if delta_prompt != "":
-            logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt)))
+            logits, prompt_token_len = self.run_rnn(
+                self.fix_tokens(self.pipeline.encode(delta_prompt))
+            )
             try:
                 state_cache.add_state(
                     state_cache.AddStateBody(
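Because generate replays cached state, prompt_token_len counts only the tokens of delta_prompt, the suffix of the prompt not covered by the state cache; a fully cached prompt reports 0. A caller that wants the absolute prompt length can encode the full prompt itself, as in this sketch (pipeline is assumed to be the same tokenizer run_rnn uses):

    # Cached prefixes are never re-run, so the counter above only sees the
    # delta; encoding the whole prompt recovers the absolute count.
    full_prompt_tokens = len(pipeline.encode(prompt))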
@@ -293,6 +298,7 @@ The following is a coherent verbose detailed conversation between a girl named {
 
         occurrence: Dict = {}
 
+        completion_token_len = 0
         response = ""
         for i in range(self.max_tokens_per_generation):
             for n in occurrence:
@@ -305,14 +311,15 @@ The following is a coherent verbose detailed conversation between a girl named {
             )
 
             if token == END_OF_TEXT:
-                yield response, ""
+                yield response, "", prompt_token_len, completion_token_len
                 break
             if token not in occurrence:
                 occurrence[token] = 1
             else:
                 occurrence[token] += 1
 
-            logits = self.run_rnn([token])
+            logits, _ = self.run_rnn([token])
+            completion_token_len = completion_token_len + 1
             delta: str = self.pipeline.decode(self.model_tokens[out_last:])
             if "\ufffd" not in delta:  # avoid utf-8 display issues
                 response += delta
@@ -330,7 +337,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                     )
                 except HTTPException:
                     pass
-                yield response, ""
+                yield response, "", prompt_token_len, completion_token_len
                 break
             out_last = begin + i + 1
             if i == self.max_tokens_per_generation - 1:
@@ -345,7 +352,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                     )
                 except HTTPException:
                     pass
-                yield response, delta
+                yield response, delta, prompt_token_len, completion_token_len
 
 
 class ModelConfigBody(BaseModel):
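Taken together, every yield from generate now carries (response, delta, prompt_token_len, completion_token_len), so a caller can stream text and read running token counts from the same tuple. A consumer sketch, assuming model is the wrapper instance and an OpenAI-style usage object is the goal (the field names are illustrative, not part of this diff):

    prompt_tokens = completion_tokens = 0
    for response, delta, prompt_tokens, completion_tokens in model.generate(prompt):
        print(delta, end="", flush=True)  # stream each decoded chunk

    # The last tuple seen holds the final counts for the whole generation.
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }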