add usage
This commit is contained in:
@@ -109,8 +109,8 @@ async def eval_rwkv(
|
||||
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||
set_rwkv_config(model, body)
|
||||
|
||||
response = ""
|
||||
for response, delta in model.generate(
|
||||
response, prompt_tokens, completion_tokens = "", 0, 0
|
||||
for response, delta, prompt_tokens, completion_tokens in model.generate(
|
||||
prompt,
|
||||
stop=stop,
|
||||
):
|
||||
@@ -184,6 +184,11 @@ async def eval_rwkv(
|
||||
"object": "chat.completion" if chat_mode else "text_completion",
|
||||
"response": response,
|
||||
"model": model.name,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
@@ -384,6 +389,7 @@ async def embeddings(body: EmbeddingsBody, request: Request):
|
||||
base64_format = True
|
||||
|
||||
embeddings = []
|
||||
prompt_tokens = 0
|
||||
if type(body.input) == list:
|
||||
if type(body.input[0]) == list:
|
||||
encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
|
||||
@@ -391,7 +397,8 @@ async def embeddings(body: EmbeddingsBody, request: Request):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
input = encoding.decode(body.input[i])
|
||||
embedding = model.get_embedding(input, body.fast_mode)
|
||||
embedding, token_len = model.get_embedding(input, body.fast_mode)
|
||||
prompt_tokens = prompt_tokens + token_len
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
@@ -399,12 +406,15 @@ async def embeddings(body: EmbeddingsBody, request: Request):
|
||||
for i in range(len(body.input)):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
embedding = model.get_embedding(body.input[i], body.fast_mode)
|
||||
embedding, token_len = model.get_embedding(
|
||||
body.input[i], body.fast_mode
|
||||
)
|
||||
prompt_tokens = prompt_tokens + token_len
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
else:
|
||||
embedding = model.get_embedding(body.input, body.fast_mode)
|
||||
embedding, prompt_tokens = model.get_embedding(body.input, body.fast_mode)
|
||||
if base64_format:
|
||||
embedding = embedding_base64(embedding)
|
||||
embeddings.append(embedding)
|
||||
@@ -438,4 +448,5 @@ async def embeddings(body: EmbeddingsBody, request: Request):
|
||||
"object": "list",
|
||||
"data": ret_data,
|
||||
"model": model.name,
|
||||
"usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user