embeddings api compatible with openai api and langchain(sdk)

This commit is contained in:
josc146
2023-06-19 22:51:06 +08:00
parent 377f71b16b
commit 8963543159
6 changed files with 285 additions and 0 deletions

View File

@@ -2,10 +2,13 @@ import asyncio
import json
from threading import Lock
from typing import List
import base64
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
import numpy as np
import tiktoken
from utils.rwkv import *
from utils.log import quick_log
import global_var
@@ -116,6 +119,9 @@ async def eval_rwkv(
if stream:
yield json.dumps(
{
"object": "chat.completion.chunk"
if chat_mode
else "text_completion",
"response": response,
"model": model.name,
"choices": [
@@ -152,6 +158,9 @@ async def eval_rwkv(
if stream:
yield json.dumps(
{
"object": "chat.completion.chunk"
if chat_mode
else "text_completion",
"response": response,
"model": model.name,
"choices": [
@@ -172,6 +181,7 @@ async def eval_rwkv(
yield "[DONE]"
else:
yield {
"object": "chat.completion" if chat_mode else "text_completion",
"response": response,
"model": model.name,
"choices": [
@@ -307,3 +317,125 @@ async def completions(body: CompletionBody, request: Request):
).__anext__()
except StopAsyncIteration:
return None
class EmbeddingsBody(BaseModel):
input: str | List[str] | List[List[int]]
model: str = "rwkv"
encoding_format: str = None
fast_mode: bool = False
class Config:
schema_extra = {
"example": {
"input": "a big apple",
"model": "rwkv",
"encoding_format": None,
"fast_mode": False,
}
}
def embedding_base64(embedding: List[float]) -> str:
return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
@router.post("/v1/embeddings")
@router.post("/embeddings")
@router.post("/v1/engines/text-embedding-ada-002/embeddings")
@router.post("/engines/text-embedding-ada-002/embeddings")
async def embeddings(body: EmbeddingsBody, request: Request):
model: RWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
if body.input is None or body.input == "" or body.input == [] or body.input == [[]]:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "input not found")
global requests_num
requests_num = requests_num + 1
quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
while completion_lock.locked():
if await request.is_disconnected():
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
await asyncio.sleep(0.1)
else:
completion_lock.acquire()
if await request.is_disconnected():
completion_lock.release()
requests_num = requests_num - 1
print(f"{request.client} Stop Waiting (Lock)")
quick_log(
request,
None,
"Stop Waiting (Lock). RequestsNum: " + str(requests_num),
)
return
base64_format = False
if body.encoding_format == "base64":
base64_format = True
embeddings = []
if type(body.input) == list:
if type(body.input[0]) == list:
encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
for i in range(len(body.input)):
if await request.is_disconnected():
break
input = encoding.decode(body.input[i])
embedding = model.get_embedding(input, body.fast_mode)
if base64_format:
embedding = embedding_base64(embedding)
embeddings.append(embedding)
else:
for i in range(len(body.input)):
if await request.is_disconnected():
break
embedding = model.get_embedding(body.input[i], body.fast_mode)
if base64_format:
embedding = embedding_base64(embedding)
embeddings.append(embedding)
else:
embedding = model.get_embedding(body.input, body.fast_mode)
if base64_format:
embedding = embedding_base64(embedding)
embeddings.append(embedding)
requests_num = requests_num - 1
completion_lock.release()
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(
request,
None,
"Stop Waiting. RequestsNum: " + str(requests_num),
)
return
quick_log(
request,
None,
"Finished. RequestsNum: " + str(requests_num),
)
ret_data = [
{
"object": "embedding",
"index": i,
"embedding": embedding,
}
for i, embedding in enumerate(embeddings)
]
return {
"object": "list",
"data": ret_data,
"model": model.name,
}