Embeddings API compatible with the OpenAI API and LangChain (SDK)

josc146 2023-06-19 22:51:06 +08:00
parent 377f71b16b
commit 8963543159
6 changed files with 285 additions and 0 deletions
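Since the new endpoints mirror the OpenAI embeddings API, a minimal client-side sketch may help illustrate the intended usage. The base URL, port, and API key below are assumptions for a locally running backend and are not part of this commit; the snippet uses the 0.x OpenAI Python SDK interface that was current at the time.

```python
# Sketch: calling the new /v1/embeddings endpoint with the OpenAI Python SDK (0.x).
# Base URL, port, and key are assumptions for a local deployment.
import openai

openai.api_base = "http://127.0.0.1:8000"  # assumed local backend address
openai.api_key = "sk-anything"             # placeholder; whether it is checked depends on the server

resp = openai.Embedding.create(model="rwkv", input="a big apple")
print(len(resp["data"][0]["embedding"]))
```

With LangChain, the same endpoint should be reachable by pointing `OpenAIEmbeddings(openai_api_base=...)` at the backend.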

View File

@@ -1,4 +1,6 @@
import tiktoken
import GPUtil
import torch
import rwkv
import fastapi

Binary file not shown.

View File

@@ -2,10 +2,13 @@ import asyncio
import json
from threading import Lock
from typing import List
import base64
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
import numpy as np
import tiktoken
from utils.rwkv import *
from utils.log import quick_log
import global_var
@@ -116,6 +119,9 @@ async def eval_rwkv(
if stream:
    yield json.dumps(
        {
            "object": "chat.completion.chunk"
            if chat_mode
            else "text_completion",
            "response": response,
            "model": model.name,
            "choices": [
@@ -152,6 +158,9 @@ async def eval_rwkv(
if stream:
    yield json.dumps(
        {
            "object": "chat.completion.chunk"
            if chat_mode
            else "text_completion",
            "response": response,
            "model": model.name,
            "choices": [
@@ -172,6 +181,7 @@ async def eval_rwkv(
    yield "[DONE]"
else:
    yield {
        "object": "chat.completion" if chat_mode else "text_completion",
        "response": response,
        "model": model.name,
        "choices": [
@@ -307,3 +317,125 @@ async def completions(body: CompletionBody, request: Request):
    ).__anext__()
except StopAsyncIteration:
    return None
class EmbeddingsBody(BaseModel):
    input: str | List[str] | List[List[int]]
    model: str = "rwkv"
    encoding_format: str = None
    fast_mode: bool = False

    class Config:
        schema_extra = {
            "example": {
                "input": "a big apple",
                "model": "rwkv",
                "encoding_format": None,
                "fast_mode": False,
            }
        }
def embedding_base64(embedding: List[float]) -> str:
    return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
@router.post("/v1/embeddings")
@router.post("/embeddings")
@router.post("/v1/engines/text-embedding-ada-002/embeddings")
@router.post("/engines/text-embedding-ada-002/embeddings")
async def embeddings(body: EmbeddingsBody, request: Request):
    model: RWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")

    if body.input is None or body.input == "" or body.input == [] or body.input == [[]]:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "input not found")

    global requests_num
    requests_num = requests_num + 1
    quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
    while completion_lock.locked():
        if await request.is_disconnected():
            requests_num = requests_num - 1
            print(f"{request.client} Stop Waiting (Lock)")
            quick_log(
                request,
                None,
                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
            )
            return
        await asyncio.sleep(0.1)
    else:
        completion_lock.acquire()
        if await request.is_disconnected():
            completion_lock.release()
            requests_num = requests_num - 1
            print(f"{request.client} Stop Waiting (Lock)")
            quick_log(
                request,
                None,
                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
            )
            return

    base64_format = False
    if body.encoding_format == "base64":
        base64_format = True

    embeddings = []
    if type(body.input) == list:
        if type(body.input[0]) == list:
            encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
            for i in range(len(body.input)):
                if await request.is_disconnected():
                    break
                input = encoding.decode(body.input[i])
                embedding = model.get_embedding(input, body.fast_mode)
                if base64_format:
                    embedding = embedding_base64(embedding)
                embeddings.append(embedding)
        else:
            for i in range(len(body.input)):
                if await request.is_disconnected():
                    break
                embedding = model.get_embedding(body.input[i], body.fast_mode)
                if base64_format:
                    embedding = embedding_base64(embedding)
                embeddings.append(embedding)
    else:
        embedding = model.get_embedding(body.input, body.fast_mode)
        if base64_format:
            embedding = embedding_base64(embedding)
        embeddings.append(embedding)

    requests_num = requests_num - 1
    completion_lock.release()
    if await request.is_disconnected():
        print(f"{request.client} Stop Waiting")
        quick_log(
            request,
            None,
            "Stop Waiting. RequestsNum: " + str(requests_num),
        )
        return
    quick_log(
        request,
        None,
        "Finished. RequestsNum: " + str(requests_num),
    )

    ret_data = [
        {
            "object": "embedding",
            "index": i,
            "embedding": embedding,
        }
        for i, embedding in enumerate(embeddings)
    ]

    return {
        "object": "list",
        "data": ret_data,
        "model": model.name,
    }
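When `encoding_format` is `"base64"`, each `embedding` field carries the base64 encoding of the raw float32 buffer produced by `embedding_base64` above. A client can recover the vector as in this sketch (assumes NumPy on the client side; the byte order is the server's native order):

```python
# Sketch: decoding the base64 embedding format produced by embedding_base64().
import base64
import numpy as np

def decode_embedding(b64: str) -> np.ndarray:
    # The server base64-encodes the raw bytes of a float32 array.
    return np.frombuffer(base64.b64decode(b64), dtype=np.float32)
```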

View File

@@ -5,6 +5,8 @@ from typing import Dict, List
from utils.log import quick_log
from fastapi import HTTPException
from pydantic import BaseModel, Field
import torch
import numpy as np
from rwkv_pip.utils import PIPELINE
from routes import state_cache
@@ -104,6 +106,155 @@ The following is a coherent verbose detailed conversation between a girl named {
            out[self.model_tokens[-1]] = -999999999
        return out
    def get_embedding(self, input: str, fast_mode: bool) -> List[float]:
        if fast_mode:
            embedding = self.fast_embedding(
                self.fix_tokens(self.pipeline.encode(input)), None
            )
        else:
            self.model_state = None
            self.model_tokens = []
            self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
            embedding = self.model_state[-5].tolist()
        embedding = (embedding / np.linalg.norm(embedding)).tolist()
        return embedding
    def fast_embedding(self, tokens: List[str], state):
        tokens = [int(x) for x in tokens]
        self = self.model
        with torch.no_grad():
            w = self.w
            args = self.args

            if state == None:
                state = [None] * args.n_layer * 5
                for i in range(
                    args.n_layer
                ):  # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
                    dd = self.strategy[i]
                    dev = dd.device
                    atype = dd.atype
                    state[i * 5 + 0] = torch.zeros(
                        args.n_embd, dtype=atype, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 1] = torch.zeros(
                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 2] = torch.zeros(
                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 3] = (
                        torch.zeros(
                            args.n_embd,
                            dtype=torch.float,
                            requires_grad=False,
                            device=dev,
                        ).contiguous()
                        - 1e30
                    )
                    state[i * 5 + 4] = torch.zeros(
                        args.n_embd, dtype=atype, requires_grad=False, device=dev
                    ).contiguous()
                    break

            seq_mode = len(tokens) > 1

            x = w["emb.weight"][tokens if seq_mode else tokens[0]]

            for i in range(args.n_layer):
                bbb = f"blocks.{i}."
                att = f"blocks.{i}.att."
                ffn = f"blocks.{i}.ffn."
                dd = self.strategy[i]
                dev = dd.device
                atype = dd.atype
                wtype = dd.wtype
                if seq_mode:
                    if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
                        ATT = (
                            self.cuda_att_seq
                            if wtype != torch.uint8
                            else self.cuda_att_seq_i8
                        )
                    else:
                        ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
                    FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
                else:
                    ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
                    FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8

                x = x.to(dtype=atype, device=dev)

                kw = w[f"{att}key.weight"]
                vw = w[f"{att}value.weight"]
                rw = w[f"{att}receptance.weight"]
                ow = w[f"{att}output.weight"]
                if dd.stream:
                    kw = kw.to(device=dev, non_blocking=True)
                    vw = vw.to(device=dev, non_blocking=True)
                    rw = rw.to(device=dev, non_blocking=True)
                    ow = ow.to(device=dev, non_blocking=True)
                kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x
                krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x
                kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x
                kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x
                vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x
                vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x
                vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x
                vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x
                rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x
                rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x
                rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x
                rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x
                omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x
                orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
                omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
                ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
                (
                    x,
                    state[i * 5 + 0],
                    state[i * 5 + 1],
                    state[i * 5 + 2],
                    state[i * 5 + 3],
                ) = ATT(
                    x,
                    state[i * 5 + 0],
                    state[i * 5 + 1],
                    state[i * 5 + 2],
                    state[i * 5 + 3],
                    w[f"{bbb}ln1.weight"],
                    w[f"{bbb}ln1.bias"],
                    w[f"{att}time_mix_k"],
                    w[f"{att}time_mix_v"],
                    w[f"{att}time_mix_r"],
                    w[f"{att}time_decay"],
                    w[f"{att}time_first"],
                    kw,
                    vw,
                    rw,
                    ow,
                    kmx,
                    krx,
                    kmy,
                    kry,
                    vmx,
                    vrx,
                    vmy,
                    vry,
                    rmx,
                    rrx,
                    rmy,
                    rry,
                    omx,
                    orx,
                    omy,
                    ory,
                )

                return state[0].tolist()
    def generate(self, prompt: str, stop: str = None):
        quick_log(None, None, "Generation Prompt:\n" + prompt)
        cache = None
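For context on how the values returned by get_embedding can be used: the method returns a plain List[float] that is L2-normalized before it is returned, so the dot product of two such vectors is already their cosine similarity. A small sketch, assuming `model` is an already-loaded `utils.rwkv.RWKV` instance:

```python
# Sketch: comparing two inputs with the embeddings produced above.
# `model` is assumed to be an already-loaded utils.rwkv.RWKV instance.
import numpy as np

def similarity(model, a: str, b: str, fast_mode: bool = False) -> float:
    va = np.array(model.get_embedding(a, fast_mode))
    vb = np.array(model.get_embedding(b, fast_mode))
    # get_embedding L2-normalizes its output, so the dot product is the cosine similarity.
    return float(np.dot(va, vb))
```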