embeddings api compatible with openai api and langchain(sdk)
parent 377f71b16b
commit 8963543159
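
This commit adds an OpenAI-compatible /v1/embeddings endpoint to the backend. A minimal client-side sketch of what that enables, assuming a local server on http://127.0.0.1:8000 and the openai 0.x SDK (the base URL and key below are assumptions, not part of this commit):

    import openai
    from langchain.embeddings import OpenAIEmbeddings

    openai.api_base = "http://127.0.0.1:8000"  # assumed local backend address
    openai.api_key = "sk-anything"  # assumption: the local server ignores the key

    # Plain OpenAI-style request against the new route
    resp = openai.Embedding.create(input="a big apple", model="rwkv")
    vector = resp["data"][0]["embedding"]

    # LangChain points at the same server; it tokenizes with tiktoken and may
    # send token-id lists, which the new handler decodes (see below).
    embeddings = OpenAIEmbeddings(
        openai_api_key="sk-anything",
        openai_api_base="http://127.0.0.1:8000",
    )
    vectors = embeddings.embed_documents(["a big apple", "a small pear"])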
@@ -1,4 +1,6 @@
+import tiktoken
 import GPUtil
+
 import torch
 import rwkv
 import fastapi
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,10 +2,13 @@ import asyncio
 import json
 from threading import Lock
 from typing import List
+import base64

 from fastapi import APIRouter, Request, status, HTTPException
 from sse_starlette.sse import EventSourceResponse
 from pydantic import BaseModel
+import numpy as np
+import tiktoken
 from utils.rwkv import *
 from utils.log import quick_log
 import global_var
@@ -116,6 +119,9 @@ async def eval_rwkv(
         if stream:
             yield json.dumps(
                 {
+                    "object": "chat.completion.chunk"
+                    if chat_mode
+                    else "text_completion",
                     "response": response,
                     "model": model.name,
                     "choices": [
@@ -152,6 +158,9 @@ async def eval_rwkv(
         if stream:
             yield json.dumps(
                 {
+                    "object": "chat.completion.chunk"
+                    if chat_mode
+                    else "text_completion",
                     "response": response,
                     "model": model.name,
                     "choices": [
@@ -172,6 +181,7 @@ async def eval_rwkv(
             yield "[DONE]"
         else:
             yield {
+                "object": "chat.completion" if chat_mode else "text_completion",
                 "response": response,
                 "model": model.name,
                 "choices": [
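
The three hunks above tag responses with the OpenAI-style "object" field. Schematically, a streamed chunk now carries this shape (a sketch with fields abridged):

    chunk = {
        "object": "chat.completion.chunk",  # "text_completion" outside chat mode
        "response": "...",
        "model": "rwkv",
        "choices": [...],
    }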
@@ -307,3 +317,125 @@ async def completions(body: CompletionBody, request: Request):
         ).__anext__()
     except StopAsyncIteration:
         return None
+
+
+class EmbeddingsBody(BaseModel):
+    input: str | List[str] | List[List[int]]
+    model: str = "rwkv"
+    encoding_format: str = None
+    fast_mode: bool = False
+
+    class Config:
+        schema_extra = {
+            "example": {
+                "input": "a big apple",
+                "model": "rwkv",
+                "encoding_format": None,
+                "fast_mode": False,
+            }
+        }
+
+
+def embedding_base64(embedding: List[float]) -> str:
+    return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
+
+
+@router.post("/v1/embeddings")
+@router.post("/embeddings")
+@router.post("/v1/engines/text-embedding-ada-002/embeddings")
+@router.post("/engines/text-embedding-ada-002/embeddings")
+async def embeddings(body: EmbeddingsBody, request: Request):
+    model: RWKV = global_var.get(global_var.Model)
+    if model is None:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
+
+    if body.input is None or body.input == "" or body.input == [] or body.input == [[]]:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "input not found")
+
+    global requests_num
+    requests_num = requests_num + 1
+    quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
+    while completion_lock.locked():
+        if await request.is_disconnected():
+            requests_num = requests_num - 1
+            print(f"{request.client} Stop Waiting (Lock)")
+            quick_log(
+                request,
+                None,
+                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
+            )
+            return
+        await asyncio.sleep(0.1)
+    else:
+        completion_lock.acquire()
+        if await request.is_disconnected():
+            completion_lock.release()
+            requests_num = requests_num - 1
+            print(f"{request.client} Stop Waiting (Lock)")
+            quick_log(
+                request,
+                None,
+                "Stop Waiting (Lock). RequestsNum: " + str(requests_num),
+            )
+            return
+
+        base64_format = False
+        if body.encoding_format == "base64":
+            base64_format = True
+
+        embeddings = []
+        if type(body.input) == list:
+            if type(body.input[0]) == list:
+                encoding = tiktoken.model.encoding_for_model("text-embedding-ada-002")
+                for i in range(len(body.input)):
+                    if await request.is_disconnected():
+                        break
+                    input = encoding.decode(body.input[i])
+                    embedding = model.get_embedding(input, body.fast_mode)
+                    if base64_format:
+                        embedding = embedding_base64(embedding)
+                    embeddings.append(embedding)
+            else:
+                for i in range(len(body.input)):
+                    if await request.is_disconnected():
+                        break
+                    embedding = model.get_embedding(body.input[i], body.fast_mode)
+                    if base64_format:
+                        embedding = embedding_base64(embedding)
+                    embeddings.append(embedding)
+        else:
+            embedding = model.get_embedding(body.input, body.fast_mode)
+            if base64_format:
+                embedding = embedding_base64(embedding)
+            embeddings.append(embedding)
+
+        requests_num = requests_num - 1
+        completion_lock.release()
+        if await request.is_disconnected():
+            print(f"{request.client} Stop Waiting")
+            quick_log(
+                request,
+                None,
+                "Stop Waiting. RequestsNum: " + str(requests_num),
+            )
+            return
+        quick_log(
+            request,
+            None,
+            "Finished. RequestsNum: " + str(requests_num),
+        )
+
+        ret_data = [
+            {
+                "object": "embedding",
+                "index": i,
+                "embedding": embedding,
+            }
+            for i, embedding in enumerate(embeddings)
+        ]
+
+        return {
+            "object": "list",
+            "data": ret_data,
+            "model": model.name,
+        }
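
Two notes on the endpoint above. The engine-style routes (/engines/text-embedding-ada-002/embeddings) are registered because OpenAI-compatible clients such as LangChain post to that path and may send tiktoken token ids, which the List[List[int]] branch decodes. And since embedding_base64 packs the vector as raw float32 bytes before base64-encoding, a client requesting encoding_format "base64" can invert it like this (a sketch; decode_embedding is our name, not part of the commit):

    import base64

    import numpy as np


    def decode_embedding(b64: str) -> list:
        # Inverse of embedding_base64: base64 -> float32 buffer -> floats
        return np.frombuffer(base64.b64decode(b64), dtype=np.float32).tolist()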
@@ -5,6 +5,8 @@ from typing import Dict, List
 from utils.log import quick_log
 from fastapi import HTTPException
 from pydantic import BaseModel, Field
+import torch
+import numpy as np
 from rwkv_pip.utils import PIPELINE
 from routes import state_cache

@@ -104,6 +106,155 @@ The following is a coherent verbose detailed conversation between a girl named {
         out[self.model_tokens[-1]] = -999999999
         return out

+    def get_embedding(self, input: str, fast_mode: bool) -> List[float]:
+        if fast_mode:
+            embedding = self.fast_embedding(
+                self.fix_tokens(self.pipeline.encode(input)), None
+            )
+        else:
+            self.model_state = None
+            self.model_tokens = []
+            self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
+            embedding = self.model_state[-5].tolist()
+        embedding = (embedding / np.linalg.norm(embedding)).tolist()
+        return embedding
+
+    def fast_embedding(self, tokens: List[str], state):
+        tokens = [int(x) for x in tokens]
+        self = self.model
+
+        with torch.no_grad():
+            w = self.w
+            args = self.args
+
+            if state == None:
+                state = [None] * args.n_layer * 5
+                for i in range(
+                    args.n_layer
+                ):  # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
+                    dd = self.strategy[i]
+                    dev = dd.device
+                    atype = dd.atype
+                    state[i * 5 + 0] = torch.zeros(
+                        args.n_embd, dtype=atype, requires_grad=False, device=dev
+                    ).contiguous()
+                    state[i * 5 + 1] = torch.zeros(
+                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
+                    ).contiguous()
+                    state[i * 5 + 2] = torch.zeros(
+                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
+                    ).contiguous()
+                    state[i * 5 + 3] = (
+                        torch.zeros(
+                            args.n_embd,
+                            dtype=torch.float,
+                            requires_grad=False,
+                            device=dev,
+                        ).contiguous()
+                        - 1e30
+                    )
+                    state[i * 5 + 4] = torch.zeros(
+                        args.n_embd, dtype=atype, requires_grad=False, device=dev
+                    ).contiguous()
+
+                    break
+
+            seq_mode = len(tokens) > 1
+
+            x = w["emb.weight"][tokens if seq_mode else tokens[0]]
+
+            for i in range(args.n_layer):
+                bbb = f"blocks.{i}."
+                att = f"blocks.{i}.att."
+                ffn = f"blocks.{i}.ffn."
+                dd = self.strategy[i]
+                dev = dd.device
+                atype = dd.atype
+                wtype = dd.wtype
+                if seq_mode:
+                    if "cuda" in str(dev) and os.environ["RWKV_CUDA_ON"] == "1":
+                        ATT = (
+                            self.cuda_att_seq
+                            if wtype != torch.uint8
+                            else self.cuda_att_seq_i8
+                        )
+                    else:
+                        ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8
+                    FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8
+                else:
+                    ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8
+                    FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8
+
+                x = x.to(dtype=atype, device=dev)
+
+                kw = w[f"{att}key.weight"]
+                vw = w[f"{att}value.weight"]
+                rw = w[f"{att}receptance.weight"]
+                ow = w[f"{att}output.weight"]
+                if dd.stream:
+                    kw = kw.to(device=dev, non_blocking=True)
+                    vw = vw.to(device=dev, non_blocking=True)
+                    rw = rw.to(device=dev, non_blocking=True)
+                    ow = ow.to(device=dev, non_blocking=True)
+                kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x
+                krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x
+                kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x
+                kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x
+                vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x
+                vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x
+                vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x
+                vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x
+                rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x
+                rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x
+                rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x
+                rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x
+                omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x
+                orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
+                omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
+                ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
+                (
+                    x,
+                    state[i * 5 + 0],
+                    state[i * 5 + 1],
+                    state[i * 5 + 2],
+                    state[i * 5 + 3],
+                ) = ATT(
+                    x,
+                    state[i * 5 + 0],
+                    state[i * 5 + 1],
+                    state[i * 5 + 2],
+                    state[i * 5 + 3],
+                    w[f"{bbb}ln1.weight"],
+                    w[f"{bbb}ln1.bias"],
+                    w[f"{att}time_mix_k"],
+                    w[f"{att}time_mix_v"],
+                    w[f"{att}time_mix_r"],
+                    w[f"{att}time_decay"],
+                    w[f"{att}time_first"],
+                    kw,
+                    vw,
+                    rw,
+                    ow,
+                    kmx,
+                    krx,
+                    kmy,
+                    kry,
+                    vmx,
+                    vrx,
+                    vmy,
+                    vry,
+                    rmx,
+                    rrx,
+                    rmy,
+                    rry,
+                    omx,
+                    orx,
+                    omy,
+                    ory,
+                )
+
+                return state[0].tolist()
+
     def generate(self, prompt: str, stop: str = None):
         quick_log(None, None, "Generation Prompt:\n" + prompt)
         cache = None
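
get_embedding above L2-normalizes its result (the slow path reads the last layer's state after a full forward pass; fast_embedding runs only the first attention block), so two returned vectors can be compared with a plain dot product. A minimal sketch, assuming a and b came from this API:

    import numpy as np


    def cosine_similarity(a, b) -> float:
        # get_embedding returns unit-length vectors, so the dot product
        # is already the cosine similarity.
        return float(np.dot(a, b))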