add usage

josc146 2023-06-20 15:55:52 +08:00
parent 4b2509e643
commit e93c77394d
2 changed files with 36 additions and 18 deletions

Changed file 1 of 2:

@@ -109,8 +109,8 @@ async def eval_rwkv(
     set_rwkv_config(model, global_var.get(global_var.Model_Config))
     set_rwkv_config(model, body)

-    response = ""
-    for response, delta in model.generate(
+    response, prompt_tokens, completion_tokens = "", 0, 0
+    for response, delta, prompt_tokens, completion_tokens in model.generate(
         prompt,
         stop=stop,
     ):
@@ -184,6 +184,11 @@ async def eval_rwkv(
             "object": "chat.completion" if chat_mode else "text_completion",
             "response": response,
             "model": model.name,
+            "usage": {
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": prompt_tokens + completion_tokens,
+            },
             "choices": [
                 {
                     "message": {
@@ -384,6 +389,7 @@ async def embeddings(body: EmbeddingsBody, request: Request):
         base64_format = True

     embeddings = []
+    prompt_tokens = 0
     if type(body.input) == list:
         if type(body.input[0]) == list:
             encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
@@ -391,7 +397,8 @@ async def embeddings(body: EmbeddingsBody, request: Request):
                 if await request.is_disconnected():
                     break
                 input = encoding.decode(body.input[i])
-                embedding = model.get_embedding(input, body.fast_mode)
+                embedding, token_len = model.get_embedding(input, body.fast_mode)
+                prompt_tokens = prompt_tokens + token_len
                 if base64_format:
                     embedding = embedding_base64(embedding)
                 embeddings.append(embedding)
@@ -399,12 +406,15 @@ async def embeddings(body: EmbeddingsBody, request: Request):
             for i in range(len(body.input)):
                 if await request.is_disconnected():
                     break
-                embedding = model.get_embedding(body.input[i], body.fast_mode)
+                embedding, token_len = model.get_embedding(
+                    body.input[i], body.fast_mode
+                )
+                prompt_tokens = prompt_tokens + token_len
                 if base64_format:
                     embedding = embedding_base64(embedding)
                 embeddings.append(embedding)
     else:
-        embedding = model.get_embedding(body.input, body.fast_mode)
+        embedding, prompt_tokens = model.get_embedding(body.input, body.fast_mode)
         if base64_format:
             embedding = embedding_base64(embedding)
         embeddings.append(embedding)
@@ -438,4 +448,5 @@ async def embeddings(body: EmbeddingsBody, request: Request):
         "object": "list",
         "data": ret_data,
         "model": model.name,
+        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
     }
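Embeddings reports usage too; since embeddings consume no completion tokens, total_tokens simply equals prompt_tokens. A minimal client-side sketch, assuming the route is mounted at /embeddings on a local server (neither the path nor the port appears in this diff):

    import requests  # any HTTP client works

    # hypothetical local endpoint; adjust host, port, and path to your deployment
    resp = requests.post(
        "http://127.0.0.1:8000/embeddings",
        json={"input": "hello world"},
    ).json()
    print(resp["usage"])  # e.g. {'prompt_tokens': 3, 'total_tokens': 3}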

Changed file 2 of 2:

@@ -1,7 +1,7 @@
 import os
 import pathlib
 import copy
-from typing import Dict, List
+from typing import Dict, List, Tuple
 from utils.log import quick_log
 from fastapi import HTTPException
 from pydantic import BaseModel, Field
@@ -71,7 +71,7 @@ The following is a coherent verbose detailed conversation between a girl named {
             else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
             + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
         )
-        logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
+        logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
         try:
             state_cache.add_state(
                 state_cache.AddStateBody(
@@ -92,6 +92,7 @@ The following is a coherent verbose detailed conversation between a girl named {
     def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
         tokens = [int(x) for x in _tokens]
+        token_len = len(tokens)
         self.model_tokens += tokens

         while len(tokens) > 0:
@@ -104,23 +105,24 @@ The following is a coherent verbose detailed conversation between a girl named {
         if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
             out[self.model_tokens[-1]] = -999999999
-        return out
+        return out, token_len

-    def get_embedding(self, input: str, fast_mode: bool) -> List[float]:
+    def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
         if fast_mode:
-            embedding = self.fast_embedding(
+            embedding, token_len = self.fast_embedding(
                 self.fix_tokens(self.pipeline.encode(input)), None
             )
         else:
             self.model_state = None
             self.model_tokens = []
-            self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
+            _, token_len = self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
             embedding = self.model_state[-5].tolist()
             embedding = (embedding / np.linalg.norm(embedding)).tolist()
-        return embedding
+        return embedding, token_len

     def fast_embedding(self, tokens: List[str], state):
         tokens = [int(x) for x in tokens]
+        token_len = len(tokens)
         self = self.model

         with torch.no_grad():
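Since run_rnn and get_embedding now return tuples, every caller has to unpack them. A minimal sketch of the new calling convention, where model stands for an instance of the model class defined in this file and the input text is illustrative:

    embedding, token_len = model.get_embedding("hello world", fast_mode=True)
    print(len(embedding), token_len)  # embedding dimension, prompt tokens consumed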
@@ -253,7 +255,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                 ory,
             )
-        return state[0].tolist()
+        return state[0].tolist(), token_len

     def generate(self, prompt: str, stop: str = None):
         quick_log(None, None, "Generation Prompt:\n" + prompt)
@@ -274,8 +276,11 @@ The following is a coherent verbose detailed conversation between a girl named {
             self.model_tokens = copy.deepcopy(cache["tokens"])
             logits = copy.deepcopy(cache["logits"])

+        prompt_token_len = 0
         if delta_prompt != "":
-            logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt)))
+            logits, prompt_token_len = self.run_rnn(
+                self.fix_tokens(self.pipeline.encode(delta_prompt))
+            )
             try:
                 state_cache.add_state(
                     state_cache.AddStateBody(
@@ -293,6 +298,7 @@ The following is a coherent verbose detailed conversation between a girl named {
         occurrence: Dict = {}
+        completion_token_len = 0
         response = ""
         for i in range(self.max_tokens_per_generation):
             for n in occurrence:
@@ -305,14 +311,15 @@ The following is a coherent verbose detailed conversation between a girl named {
             )
             if token == END_OF_TEXT:
-                yield response, ""
+                yield response, "", prompt_token_len, completion_token_len
                 break
             if token not in occurrence:
                 occurrence[token] = 1
             else:
                 occurrence[token] += 1
-            logits = self.run_rnn([token])
+            logits, _ = self.run_rnn([token])
+            completion_token_len = completion_token_len + 1
             delta: str = self.pipeline.decode(self.model_tokens[out_last:])
             if "\ufffd" not in delta:  # avoid utf-8 display issues
                 response += delta
@@ -330,7 +337,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                         )
                     except HTTPException:
                         pass
-                    yield response, ""
+                    yield response, "", prompt_token_len, completion_token_len
                     break
                 out_last = begin + i + 1
             if i == self.max_tokens_per_generation - 1:
@@ -345,7 +352,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                 )
             except HTTPException:
                 pass
-        yield response, delta
+        yield response, delta, prompt_token_len, completion_token_len


 class ModelConfigBody(BaseModel):
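Because generate now yields four-tuples, callers bind the token counts directly in the loop, and the final pair of counts feeds the usage block added in the first file. A minimal consumer sketch, mirroring the eval_rwkv change above (prompt and stop are placeholders):

    response, prompt_tokens, completion_tokens = "", 0, 0
    for response, delta, prompt_tokens, completion_tokens in model.generate(prompt, stop=stop):
        pass  # stream each delta to the client here
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }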