diff --git a/backend-python/dep_check.py b/backend-python/dep_check.py
index ee24f15..82c10c2 100644
--- a/backend-python/dep_check.py
+++ b/backend-python/dep_check.py
@@ -1,4 +1,6 @@
+import tiktoken
 import GPUtil
+
 import torch
 import rwkv
 import fastapi
diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt
index b3239db..d9030ae 100644
Binary files a/backend-python/requirements.txt and b/backend-python/requirements.txt differ
diff --git a/backend-python/requirements_versions.txt b/backend-python/requirements_versions.txt
index 76e6989..3cb41fd 100644
Binary files a/backend-python/requirements_versions.txt and b/backend-python/requirements_versions.txt differ
diff --git a/backend-python/requirements_without_cyac.txt b/backend-python/requirements_without_cyac.txt
index bad09f4..34e0f0c 100644
Binary files a/backend-python/requirements_without_cyac.txt and b/backend-python/requirements_without_cyac.txt differ
diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py
index 841028d..f2d1b62 100644
--- a/backend-python/routes/completion.py
+++ b/backend-python/routes/completion.py
@@ -2,10 +2,13 @@ import asyncio
 import json
 from threading import Lock
 from typing import List
+import base64
 
 from fastapi import APIRouter, Request, status, HTTPException
 from sse_starlette.sse import EventSourceResponse
 from pydantic import BaseModel
+import numpy as np
+import tiktoken
 from utils.rwkv import *
 from utils.log import quick_log
 import global_var
@@ -116,6 +119,9 @@ async def eval_rwkv(
                 if stream:
                     yield json.dumps(
                         {
+                            "object": "chat.completion.chunk"
+                            if chat_mode
+                            else "text_completion",
                             "response": response,
                             "model": model.name,
                             "choices": [
@@ -152,6 +158,9 @@ async def eval_rwkv(
             if stream:
                 yield json.dumps(
                     {
+                        "object": "chat.completion.chunk"
+                        if chat_mode
+                        else "text_completion",
                         "response": response,
                         "model": model.name,
                         "choices": [
@@ -172,6 +181,7 @@ async def eval_rwkv(
                 yield "[DONE]"
             else:
                 yield {
+                    "object": "chat.completion" if chat_mode else "text_completion",
                     "response": response,
                     "model": model.name,
                     "choices": [
@@ -307,3 +317,125 @@ async def completions(body: CompletionBody, request: Request):
         ).__anext__()
     except StopAsyncIteration:
         return None
+
+
+class EmbeddingsBody(BaseModel):
+    input: str | List[str] | List[List[int]]
+    model: str = "rwkv"
+    encoding_format: str = None
+    fast_mode: bool = False
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "input": "a big apple",
+                "model": "rwkv",
+                "encoding_format": None,
+                "fast_mode": False,
+            }
+        }
+
+
+def embedding_base64(embedding: List[float]) -> str:
+    return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
+
+
+@router.post("/v1/embeddings")
+@router.post("/embeddings")
+@router.post("/v1/engines/text-embedding-ada-002/embeddings")
+@router.post("/engines/text-embedding-ada-002/embeddings")
+async def embeddings(body: EmbeddingsBody, request: Request):
+    model: RWKV = global_var.get(global_var.Model)
+    if model is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
+
+    if body.input is None or body.input == "" or body.input == [] or body.input == [[]]:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "input not found")
+
+    global requests_num
+    requests_num = requests_num + 1
+    quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
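+    # Wait for any running completion to finish; bail out early if the client disconnects.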
+    while completion_lock.locked():
+        if await request.is_disconnected():
+            requests_num = requests_num - 1
+            print(f"{request.client} Stop Waiting (Lock)")
+            quick_log(
+                request,
+                None,
+                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
+            )
+            return
+        await asyncio.sleep(0.1)
+    else:
+        completion_lock.acquire()
+        if await request.is_disconnected():
+            completion_lock.release()
+            requests_num = requests_num - 1
+            print(f"{request.client} Stop Waiting (Lock)")
+            quick_log(
+                request,
+                None,
+                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
+            )
+            return
+
+    base64_format = False
+    if body.encoding_format == "base64":
+        base64_format = True
+
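+    # body.input may be a single string, a list of strings, or a list of token-id lists
+    # (token ids are decoded with the text-embedding-ada-002 tiktoken encoding first).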
+    embeddings = []
+    if type(body.input) == list:
+        if type(body.input[0]) == list:
+            encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
+            for i in range(len(body.input)):
+                if await request.is_disconnected():
+                    break
+                input = encoding.decode(body.input[i])
+                embedding = model.get_embedding(input, body.fast_mode)
+                if base64_format:
+                    embedding = embedding_base64(embedding)
+                embeddings.append(embedding)
+        else:
+            for i in range(len(body.input)):
+                if await request.is_disconnected():
+                    break
+                embedding = model.get_embedding(body.input[i], body.fast_mode)
+                if base64_format:
+                    embedding = embedding_base64(embedding)
+                embeddings.append(embedding)
+    else:
+        embedding = model.get_embedding(body.input, body.fast_mode)
+        if base64_format:
+            embedding = embedding_base64(embedding)
+        embeddings.append(embedding)
+
+    requests_num = requests_num - 1
+    completion_lock.release()
+    if await request.is_disconnected():
+        print(f"{request.client} Stop Waiting")
+        quick_log(
+            request,
+            None,
+            "Stop Waiting. RequestsNum: " + str(requests_num),
+        )
+        return
+    quick_log(
+        request,
+        None,
+        "Finished. RequestsNum: " + str(requests_num),
+    )
+
+    ret_data = [
+        {
+            "object": "embedding",
+            "index": i,
+            "embedding": embedding,
+        }
+        for i, embedding in enumerate(embeddings)
+    ]
+
+    return {
+        "object": "list",
+        "data": ret_data,
+        "model": model.name,
+    }
diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py
index 208a961..be089fa 100644
--- a/backend-python/utils/rwkv.py
+++ b/backend-python/utils/rwkv.py
@@ -5,6 +5,8 @@ from typing import Dict, List
 from utils.log import quick_log
 from fastapi import HTTPException
 from pydantic import BaseModel, Field
+import torch
+import numpy as np
 from rwkv_pip.utils import PIPELINE
 from routes import state_cache
 
@@ -104,6 +106,155 @@ The following is a coherent verbose detailed conversation between a girl named {
         out[self.model_tokens[-1]] = -999999999
         return out
 
+    def get_embedding(self, input: str, fast_mode: bool) -> List[float]:
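+        # fast_mode runs only the first RWKV block over the prompt (see fast_embedding);
+        # the default path evaluates the whole model through run_rnn.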
+        if fast_mode:
+            embedding = self.fast_embedding(
+                self.fix_tokens(self.pipeline.encode(input)), None
+            )
+        else:
+            self.model_state = None
+            self.model_tokens = []
+            self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
+            embedding = self.model_state[-5].tolist()
+        embedding = (embedding / np.linalg.norm(embedding)).tolist()
+        return embedding
+
+    def fast_embedding(self, tokens: List[str], state):
+        tokens = [int(x) for x in tokens]
+        self = self.model
+
+        with torch.no_grad():
+            w = self.w
+            args = self.args
+
+            if state == None:
+                state = [None] * args.n_layer * 5
+                for i in range(
+                    args.n_layer
+                ):  # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
+                    dd = self.strategy[i]
+                    dev = dd.device
+                    atype = dd.atype
+                    state[i * 5 + 0] = torch.zeros(
+                        args.n_embd, dtype=atype, requires_grad=False, device=dev
+                    ).contiguous()
+                    state[i * 5 + 1] = torch.zeros(
+                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
+                    ).contiguous()
+                    state[i * 5 + 2] = torch.zeros(
+                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
+                    ).contiguous()
+                    state[i * 5 + 3] = (
+                        torch.zeros(
+                            args.n_embd,
+                            dtype=torch.float,
+                            requires_grad=False,
+                            device=dev,
+                        ).contiguous()
+                        - 1e30
+                    )
+                    state[i * 5 + 4] = torch.zeros(
+                        args.n_embd, dtype=atype, requires_grad=False, device=dev
+                    ).contiguous()
+
+                    break
+
+            seq_mode = len(tokens) > 1
+
+            x = w["emb.weight"][tokens if seq_mode else tokens[0]]
+
+            for i in range(args.n_layer):
+                bbb = f"blocks.{i}."
+                att = f"blocks.{i}.att."
+                ffn = f"blocks.{i}.ffn."
+                dd = self.strategy[i]
+                dev = dd.device
+                atype = dd.atype
+                wtype = dd.wtype
+                if seq_mode:
+                    if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
+                        ATT = (
+                            self.cuda_att_seq
+                            if wtype != torch.uint8
+                            else self.cuda_att_seq_i8
+                        )
+                    else:
+                        ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
+                    FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
+                else:
+                    ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
+                    FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
+
+                x = x.to(dtype=atype, device=dev)
+
+                kw = w[f"{att}key.weight"]
+                vw = w[f"{att}value.weight"]
+                rw = w[f"{att}receptance.weight"]
+                ow = w[f"{att}output.weight"]
+                if dd.stream:
+                    kw = kw.to(device=dev, non_blocking=True)
+                    vw = vw.to(device=dev, non_blocking=True)
+                    rw = rw.to(device=dev, non_blocking=True)
+                    ow = ow.to(device=dev, non_blocking=True)
+                kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x
+                krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x
+                kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x
+                kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x
+                vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x
+                vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x
+                vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x
+                vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x
+                rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x
+                rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x
+                rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x
+                rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x
+                omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x
+                orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
+                omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
+                ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
+                (
+                    x,
+                    state[i * 5 + 0],
+                    state[i * 5 + 1],
+                    state[i * 5 + 2],
+                    state[i * 5 + 3],
+                ) = ATT(
+                    x,
+                    state[i * 5 + 0],
+                    state[i * 5 + 1],
+                    state[i * 5 + 2],
+                    state[i * 5 + 3],
+                    w[f"{bbb}ln1.weight"],
+                    w[f"{bbb}ln1.bias"],
+                    w[f"{att}time_mix_k"],
+                    w[f"{att}time_mix_v"],
+                    w[f"{att}time_mix_r"],
+                    w[f"{att}time_decay"],
+                    w[f"{att}time_first"],
+                    kw,
+                    vw,
+                    rw,
+                    ow,
+                    kmx,
+                    krx,
+                    kmy,
+                    kry,
+                    vmx,
+                    vrx,
+                    vmy,
+                    vry,
+                    rmx,
+                    rrx,
+                    rmy,
+                    rry,
+                    omx,
+                    orx,
+                    omy,
+                    ory,
+                )
+
+                return state[0].tolist()
+
     def generate(self, prompt: str, stop: str = None):
         quick_log(None, None, "Generation Prompt:\n" + prompt)
         cache = None