diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py
index f2d1b62..af6f47c 100644
--- a/backend-python/routes/completion.py
+++ b/backend-python/routes/completion.py
@@ -109,8 +109,8 @@ async def eval_rwkv(
         set_rwkv_config(model, global_var.get(global_var.Model_Config))
         set_rwkv_config(model, body)
 
-        response = ""
-        for response, delta in model.generate(
+        response, prompt_tokens, completion_tokens = "", 0, 0
+        for response, delta, prompt_tokens, completion_tokens in model.generate(
             prompt,
             stop=stop,
         ):
@@ -184,6 +184,11 @@ async def eval_rwkv(
                 "object": "chat.completion" if chat_mode else "text_completion",
                 "response": response,
                 "model": model.name,
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": prompt_tokens + completion_tokens,
+                },
                 "choices": [
                     {
                         "message": {
@@ -384,6 +389,7 @@ async def embeddings(body: EmbeddingsBody, request: Request):
         base64_format = True
 
     embeddings = []
+    prompt_tokens = 0
     if type(body.input) == list:
         if type(body.input[0]) == list:
             encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
@@ -391,7 +397,8 @@ async def embeddings(body: EmbeddingsBody, request: Request):
                 if await request.is_disconnected():
                     break
                 input = encoding.decode(body.input[i])
-                embedding = model.get_embedding(input, body.fast_mode)
+                embedding, token_len = model.get_embedding(input, body.fast_mode)
+                prompt_tokens = prompt_tokens + token_len
                 if base64_format:
                     embedding = embedding_base64(embedding)
                 embeddings.append(embedding)
@@ -399,12 +406,15 @@ async def embeddings(body: EmbeddingsBody, request: Request):
             for i in range(len(body.input)):
                 if await request.is_disconnected():
                     break
-                embedding = model.get_embedding(body.input[i], body.fast_mode)
+                embedding, token_len = model.get_embedding(
+                    body.input[i], body.fast_mode
+                )
+                prompt_tokens = prompt_tokens + token_len
                 if base64_format:
                     embedding = embedding_base64(embedding)
                 embeddings.append(embedding)
     else:
-        embedding = model.get_embedding(body.input, body.fast_mode)
+        embedding, prompt_tokens = model.get_embedding(body.input, body.fast_mode)
         if base64_format:
             embedding = embedding_base64(embedding)
         embeddings.append(embedding)
@@ -438,4 +448,5 @@ async def embeddings(body: EmbeddingsBody, request: Request):
         "object": "list",
         "data": ret_data,
         "model": model.name,
+        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
     }
diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py
index be089fa..7f01c35 100644
--- a/backend-python/utils/rwkv.py
+++ b/backend-python/utils/rwkv.py
@@ -1,7 +1,7 @@
 import os
 import pathlib
 import copy
-from typing import Dict, List
+from typing import Dict, List, Tuple
 from utils.log import quick_log
 from fastapi import HTTPException
 from pydantic import BaseModel, Field
@@ -71,7 +71,7 @@ The following is a coherent verbose detailed conversation between a girl named {
             else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
             + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
         )
-        logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
+        logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
         try:
             state_cache.add_state(
                 state_cache.AddStateBody(
@@ -92,6 +92,7 @@ The following is a coherent verbose detailed conversation between a girl named {
 
     def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
         tokens = [int(x) for x in _tokens]
+        token_len = len(tokens)
         self.model_tokens += tokens
 
         while len(tokens) > 0:
@@ -104,23 +105,24 @@ The following is a coherent verbose detailed conversation between a girl named {
 
         if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
             out[self.model_tokens[-1]] = -999999999
-        return out
+        return out, token_len
 
-    def get_embedding(self, input: str, fast_mode: bool) -> List[float]:
+    def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
         if fast_mode:
-            embedding = self.fast_embedding(
+            embedding, token_len = self.fast_embedding(
                 self.fix_tokens(self.pipeline.encode(input)), None
             )
         else:
             self.model_state = None
             self.model_tokens = []
-            self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
+            _, token_len = self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
             embedding = self.model_state[-5].tolist()
         embedding = (embedding / np.linalg.norm(embedding)).tolist()
-        return embedding
+        return embedding, token_len
 
     def fast_embedding(self, tokens: List[str], state):
         tokens = [int(x) for x in tokens]
+        token_len = len(tokens)
         self = self.model
 
         with torch.no_grad():
@@ -253,7 +255,7 @@ The following is a coherent verbose detailed conversation between a girl named {
             ory,
         )
-        return state[0].tolist()
+        return state[0].tolist(), token_len
 
     def generate(self, prompt: str, stop: str = None):
         quick_log(None, None, "Generation Prompt:\n" + prompt)
@@ -274,8 +276,11 @@ The following is a coherent verbose detailed conversation between a girl named {
             self.model_tokens = copy.deepcopy(cache["tokens"])
             logits = copy.deepcopy(cache["logits"])
 
+        prompt_token_len = 0
         if delta_prompt != "":
-            logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt)))
+            logits, prompt_token_len = self.run_rnn(
+                self.fix_tokens(self.pipeline.encode(delta_prompt))
+            )
             try:
                 state_cache.add_state(
                     state_cache.AddStateBody(
@@ -293,6 +298,7 @@ The following is a coherent verbose detailed conversation between a girl named {
 
         occurrence: Dict = {}
 
+        completion_token_len = 0
         response = ""
         for i in range(self.max_tokens_per_generation):
             for n in occurrence:
@@ -305,14 +311,15 @@ The following is a coherent verbose detailed conversation between a girl named {
             )
 
             if token == END_OF_TEXT:
-                yield response, ""
+                yield response, "", prompt_token_len, completion_token_len
                 break
             if token not in occurrence:
                 occurrence[token] = 1
             else:
                 occurrence[token] += 1
 
-            logits = self.run_rnn([token])
+            logits, _ = self.run_rnn([token])
+            completion_token_len = completion_token_len + 1
             delta: str = self.pipeline.decode(self.model_tokens[out_last:])
             if "\ufffd" not in delta:  # avoid utf-8 display issues
                 response += delta
@@ -330,7 +337,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                     )
                 except HTTPException:
                     pass
-                yield response, ""
+                yield response, "", prompt_token_len, completion_token_len
                 break
             out_last = begin + i + 1
             if i == self.max_tokens_per_generation - 1:
@@ -345,7 +352,7 @@ The following is a coherent verbose detailed conversation between a girl named {
                     )
                 except HTTPException:
                     pass
-        yield response, delta
+        yield response, delta, prompt_token_len, completion_token_len
 
 
 class ModelConfigBody(BaseModel):
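Note: below is a minimal consumer sketch, not part of the patch. It shows how a caller drives the extended `generate()` API and assembles the OpenAI-style `usage` object the way `eval_rwkv` now does; `run_completion` is a hypothetical helper, and the FastAPI streaming wiring is omitted.

```python
def run_completion(model, prompt: str, stop: str = None) -> dict:
    # generate() now yields (response, delta, prompt_tokens, completion_tokens).
    # The token counts are running totals, so the last yielded values are
    # sufficient for the final usage report.
    response, prompt_tokens, completion_tokens = "", 0, 0
    for response, delta, prompt_tokens, completion_tokens in model.generate(
        prompt, stop=stop
    ):
        print(delta, end="", flush=True)  # a streaming server would forward this chunk
    return {
        "response": response,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
```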
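Likewise for the embeddings route: `get_embedding()` now returns a `(vector, token_len)` tuple, and since embeddings have no completion phase, `total_tokens` equals `prompt_tokens`. A sketch under the same assumptions (`embed_with_usage` and `texts` are illustrative, standing in for the route's handling of `body.input`):

```python
def embed_with_usage(model, texts: list) -> dict:
    embeddings, prompt_tokens = [], 0
    for text in texts:
        # The second tuple element is the number of tokens consumed by this input.
        embedding, token_len = model.get_embedding(text, False)  # fast_mode=False
        prompt_tokens += token_len
        embeddings.append(embedding)
    return {
        "data": embeddings,
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }
```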