embeddings api compatible with openai api and langchain(sdk)

This commit is contained in:
josc146 2023-06-19 22:51:06 +08:00
parent 377f71b16b
commit 8963543159
6 changed files with 285 additions and 0 deletions

View File

@ -1,4 +1,6 @@
import tiktoken
import GPUtil
import torch
import rwkv
import fastapi

Binary file not shown.

View File

@ -2,10 +2,13 @@ import asyncio
import json
from threading import Lock
from typing import List
import base64
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
import numpy as np
import tiktoken
from utils.rwkv import *
from utils.log import quick_log
import global_var
@ -116,6 +119,9 @@ async def eval_rwkv(
if stream:
yield json.dumps(
{
"object": "chat.completion.chunk"
if chat_mode
else "text_completion",
"response": response,
"model": model.name,
"choices": [
@ -152,6 +158,9 @@ async def eval_rwkv(
if stream:
yield json.dumps(
{
"object": "chat.completion.chunk"
if chat_mode
else "text_completion",
"response": response,
"model": model.name,
"choices": [
@ -172,6 +181,7 @@ async def eval_rwkv(
yield "[DONE]"
else:
yield {
"object": "chat.completion" if chat_mode else "text_completion",
"response": response,
"model": model.name,
"choices": [
@ -307,3 +317,125 @@ async def completions(body: CompletionBody, request: Request):
).__anext__()
except StopAsyncIteration:
return None
class EmbeddingsBody(BaseModel):
input: str | List[str] | List[List[int]]
model: str = "rwkv"
encoding_format: str = None
fast_mode: bool = False
class Config:
schema_extra = {
"example": {
"input": "a big apple",
"model": "rwkv",
"encoding_format": None,
"fast_mode": False,
}
}
def embedding_base64(embedding: List[float]) -> str:
return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
@router.post("/v1/embeddings")
@router.post("/embeddings")
@router.post("/v1/engines/text-embedding-ada-002/embeddings")
@router.post("/engines/text-embedding-ada-002/embeddings")
async def embeddings(body: EmbeddingsBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
if body.input is None or body.input == "" or body.input == [] or body.input == [[]]:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "input not found")
global requests_num
requests_num = requests_num + 1
quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
while completion_lock.locked():
if await request.is_disconnected():
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
await asyncio.sleep(0.1)
else:
completion_lock.acquire()
if await request.is_disconnected():
completion_lock.release()
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
base64_format = False
if body.encoding_format == "base64":
base64_format = True
embeddings = []
if type(body.input) == list:
if type(body.input[0]) == list:
encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
for i in range(len(body.input)):
if await request.is_disconnected():
break
input = encoding.decode(body.input[i])
embedding = model.get_embedding(input, body.fast_mode)
if base64_format:
embedding = embedding_base64(embedding)
embeddings.append(embedding)
else:
for i in range(len(body.input)):
if await request.is_disconnected():
break
embedding = model.get_embedding(body.input[i], body.fast_mode)
if base64_format:
embedding = embedding_base64(embedding)
embeddings.append(embedding)
else:
embedding = model.get_embedding(body.input, body.fast_mode)
if base64_format:
embedding = embedding_base64(embedding)
embeddings.append(embedding)
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
None,
"Stop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
None,
"Finished. RequestsNum: " + str(requests_num),
)
ret_data = [
{
"object": "embedding",
"index": i,
"embedding": embedding,
}
for i, embedding in enumerate(embeddings)
]
return {
"object": "list",
"data": ret_data,
"model": model.name,
}

View File

@ -5,6 +5,8 @@ from typing import Dict, List
from utils.log import quick_log
from fastapi import HTTPException
from pydantic import BaseModel, Field
import torch
import numpy as np
from rwkv_pip.utils import PIPELINE
from routes import state_cache
@ -104,6 +106,155 @@ The following is a coherent verbose detailed conversation between a girl named {
out[self.model_tokens[-1]] = -999999999
return out
def get_embedding(self, input: str, fast_mode: bool) -> List[float]:
if fast_mode:
embedding = self.fast_embedding(
self.fix_tokens(self.pipeline.encode(input)), None
)
else:
self.model_state = None
self.model_tokens = []
self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
embedding = self.model_state[-5].tolist()
embedding = (embedding / np.linalg.norm(embedding)).tolist()
return embedding
def fast_embedding(self, tokens: List[str], state):
tokens = [int(x) for x in tokens]
self = self.model
with torch.no_grad():
w = self.w
args = self.args
if state == None:
state = [None] * args.n_layer * 5
for i in range(
args.n_layer
): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
dd = self.strategy[i]
dev = dd.device
atype = dd.atype
state[i * 5 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 1] = torch.zeros(
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 2] = torch.zeros(
args.n_embd, dtype=torch.float, requires_grad=False, device=dev
).contiguous()
state[i * 5 + 3] = (
torch.zeros(
args.n_embd,
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
- 1e30
)
state[i * 5 + 4] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
break
seq_mode = len(tokens) > 1
x = w["emb.weight"][tokens if seq_mode else tokens[0]]
for i in range(args.n_layer):
bbb = f"blocks.{i}."
att = f"blocks.{i}.att."
ffn = f"blocks.{i}.ffn."
dd = self.strategy[i]
dev = dd.device
atype = dd.atype
wtype = dd.wtype
if seq_mode:
if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
ATT = (
self.cuda_att_seq
if wtype != torch.uint8
else self.cuda_att_seq_i8
)
else:
ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
else:
ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
x = x.to(dtype=atype, device=dev)
kw = w[f"{att}key.weight"]
vw = w[f"{att}value.weight"]
rw = w[f"{att}receptance.weight"]
ow = w[f"{att}output.weight"]
if dd.stream:
kw = kw.to(device=dev, non_blocking=True)
vw = vw.to(device=dev, non_blocking=True)
rw = rw.to(device=dev, non_blocking=True)
ow = ow.to(device=dev, non_blocking=True)
kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x
krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x
kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x
kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x
vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x
vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x
vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x
vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x
rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x
rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x
rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x
rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x
omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x
orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
(
x,
state[i * 5 + 0],
state[i * 5 + 1],
state[i * 5 + 2],
state[i * 5 + 3],
) = ATT(
x,
state[i * 5 + 0],
state[i * 5 + 1],
state[i * 5 + 2],
state[i * 5 + 3],
w[f"{bbb}ln1.weight"],
w[f"{bbb}ln1.bias"],
w[f"{att}time_mix_k"],
w[f"{att}time_mix_v"],
w[f"{att}time_mix_r"],
w[f"{att}time_decay"],
w[f"{att}time_first"],
kw,
vw,
rw,
ow,
kmx,
krx,
kmy,
kry,
vmx,
vrx,
vmy,
vry,
rmx,
rrx,
rmy,
rry,
omx,
orx,
omy,
ory,
)
return state[0].tolist()
def generate(self, prompt: str, stop: str = None):
quick_log(None, None, "Generation Prompt:\n" + prompt)
cache = None