backend api

josc146 2023-05-07 17:27:54 +08:00
parent 4795514e8f
commit 0e852daf43
8 changed files with 188 additions and 102 deletions

.gitignore

@@ -1,6 +1,7 @@
 build/bin
 node_modules
 frontend/dist
+__pycache__
 .idea
 .vs
 package.json.md5

backend-python/global_var.py

@@ -0,0 +1,27 @@
from enum import Enum, auto

Model = 'model'
Model_Status = 'model_status'


class ModelStatus(Enum):
    Offline = auto()
    Loading = auto()
    Working = auto()


def init():
    global GLOBALS
    GLOBALS = {}
    set(Model_Status, ModelStatus.Offline)


def set(key, value):
    GLOBALS[key] = value


def get(key):
    if key in GLOBALS:
        return GLOBALS[key]
    else:
        return None
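
This registry replaces the module-level global model that main.py used before. A minimal usage sketch, assuming the module layout above (the calls below are illustrative, not part of the commit):

import global_var

global_var.init()                                  # creates GLOBALS and marks the model Offline
global_var.set(global_var.Model, None)             # stash a value under a well-known key
status = global_var.get(global_var.Model_Status)   # -> ModelStatus.Offline
missing = global_var.get("no-such-key")            # -> None instead of raising KeyError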

backend-python/main.py

@@ -1,42 +1,15 @@
-import json
-import pathlib
-import sys
-from typing import List
 import os
-import sysconfig
+import psutil
-from fastapi import FastAPI, Request, status, HTTPException
+from fastapi import FastAPI
-from langchain.llms import RWKV
-from pydantic import BaseModel
-from sse_starlette.sse import EventSourceResponse
 from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
-from rwkv_helper import rwkv_generate
+from utils.rwkv import *
+from utils.torch import *
+from utils.ngrok import *
+from routes import completion, config
+import global_var
-
-
-def set_torch():
-    torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
-    paths = os.environ.get("PATH", "")
-    if os.path.exists(torch_path):
-        print(f"torch found: {torch_path}")
-        if torch_path in paths:
-            print("torch already set")
-        else:
-            print("run:")
-            os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
-            print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
-    else:
-        print("torch not found")
-
-
-def torch_gc():
-    import torch
-    if torch.cuda.is_available():
-        with torch.cuda.device(0):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
 
 app = FastAPI()
@@ -49,87 +22,33 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+app.include_router(completion.router)
+app.include_router(config.router)
+
 
 @app.on_event('startup')
 def init():
-    global model
+    global_var.init()
     set_torch()
-    model = RWKV(
-        model=sys.argv[2],
-        strategy=sys.argv[1],
-        tokens_path=f"{pathlib.Path(__file__).parent.resolve()}/20B_tokenizer.json"
-    )
     if os.environ.get("ngrok_token") is not None:
         ngrok_connect()
 
-
-def ngrok_connect():
-    from pyngrok import ngrok, conf
-    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
-    ngrok.set_auth_token(os.environ["ngrok_token"])
-    http_tunnel = ngrok.connect(8000)
-    print(http_tunnel.public_url)
-
-
-class Message(BaseModel):
-    role: str
-    content: str
-
-
-class Body(BaseModel):
-    messages: List[Message]
-    model: str
-    stream: bool
-    max_tokens: int
-
 
 @app.get("/")
 def read_root():
     return {"Hello": "World!"}
 
 
-@app.post("update-config")
-def updateConfig(body: Body):
-    pass
-
-
-@app.post("/v1/chat/completions")
-@app.post("/chat/completions")
-async def completions(body: Body, request: Request):
-    global model
-
-    question = body.messages[-1]
-    if question.role == 'user':
-        question = question.content
-    else:
-        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
-
-    completion_text = ""
-    for message in body.messages:
-        if message.role == 'user':
-            completion_text += "Bob: " + message.content + "\n\n"
-        elif message.role == 'assistant':
-            completion_text += "Alice: " + message.content + "\n\n"
-    completion_text += "Alice:"
-
-    async def eval_rwkv():
-        if body.stream:
-            for response, delta in rwkv_generate(model, completion_text):
-                if await request.is_disconnected():
-                    break
-                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
-            yield "[DONE]"
-        else:
-            response = None
-            for response, delta in rwkv_generate(model, completion_text):
-                pass
-            yield json.dumps({"response": response, "model": "rwkv"})
-        # torch_gc()
-
-    return EventSourceResponse(eval_rwkv())
+@app.post("/exit")
+def read_root():
+    parent_pid = os.getpid()
+    parent = psutil.Process(parent_pid)
+    for child in parent.children(recursive=True):
+        child.kill()
+    parent.kill()
 
 
 if __name__ == "__main__":
-    uvicorn.run("main:app", reload=False, app_dir="backend-python")
+    uvicorn.run("main:app", port=8000)
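
Since the new /exit route kills the whole process tree through psutil, the HTTP response may never arrive. A hypothetical client call showing that behavior (the host, port, and the requests dependency are assumptions, not part of this commit):

import requests

try:
    requests.post("http://127.0.0.1:8000/exit", timeout=2)
except requests.exceptions.RequestException:
    pass  # expected: the server kills itself before it can finish the response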

backend-python/routes/completion.py

@@ -0,0 +1,58 @@
import json
from typing import List

from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel
from utils.rwkv import *
import global_var

router = APIRouter()


class Message(BaseModel):
    role: str
    content: str


class CompletionBody(BaseModel):
    messages: List[Message]
    model: str
    stream: bool
    max_tokens: int


@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def completions(body: CompletionBody, request: Request):
    model = global_var.get(global_var.Model)

    question = body.messages[-1]
    if question.role == 'user':
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")

    # map chat roles onto the Bob/Alice dialog format the prompt is built from
    completion_text = ""
    for message in body.messages:
        if message.role == 'user':
            completion_text += "Bob: " + message.content + "\n\n"
        elif message.role == 'assistant':
            completion_text += "Alice: " + message.content + "\n\n"
    completion_text += "Alice:"

    async def eval_rwkv():
        if body.stream:
            for response, delta in rwkv_generate(model, completion_text):
                if await request.is_disconnected():
                    break
                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
            yield "[DONE]"
        else:
            response = None
            for response, delta in rwkv_generate(model, completion_text):
                pass
            yield json.dumps({"response": response, "model": "rwkv"})
        # torch_gc()

    return EventSourceResponse(eval_rwkv())
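
For reference, a hypothetical client for the relocated endpoint. The host, port, and the requests dependency are assumptions; the JSON fields mirror CompletionBody above. Even with stream set to false the route replies as a server-sent-event stream, so the client reads the data: lines:

import json
import requests

resp = requests.post(
    "http://127.0.0.1:8000/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "model": "rwkv",
        "stream": False,
        "max_tokens": 1000,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line.startswith(b"data:"):
        print(json.loads(line[len(b"data:"):]))  # {"response": ..., "model": "rwkv"}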

backend-python/routes/config.py

@@ -0,0 +1,46 @@
import pathlib
import sys

from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel
from langchain.llms import RWKV
from utils.rwkv import *
from utils.torch import *
import global_var

router = APIRouter()


class UpdateConfigBody(BaseModel):
    model: str = None
    strategy: str = None
    max_response_token: int = None
    temperature: float = None
    top_p: float = None
    presence_penalty: float = None
    count_penalty: float = None


@router.post("/update-config")
def update_config(body: UpdateConfigBody):
    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
        return "loading"

    # unload the current model and reclaim its memory before loading the new one
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
    global_var.set(global_var.Model, None)
    torch_gc()

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
    try:
        # note: this revision still reads the model path and strategy from sys.argv,
        # not from the request body
        global_var.set(global_var.Model, RWKV(
            model=sys.argv[2],
            strategy=sys.argv[1],
            tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
        ))
    except Exception:
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")

    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
    return "success"
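
A hypothetical client call for the reload endpoint (host, port, and requests are assumptions). All UpdateConfigBody fields are optional, so an empty body is enough to trigger a reload from the launch arguments:

import requests

r = requests.post("http://127.0.0.1:8000/update-config", json={})
print(r.text)  # "success" once loaded; "loading" if a reload is already in flight; HTTP 500 on failure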

backend-python/utils/ngrok.py

@@ -0,0 +1,9 @@
import os


def ngrok_connect():
    from pyngrok import ngrok, conf
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(8000)
    print(http_tunnel.public_url)
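
main.py only calls ngrok_connect when the ngrok_token environment variable is present, so enabling the tunnel is a matter of setting the token before startup, roughly:

import os

os.environ["ngrok_token"] = "<YOUR_NGROK_TOKEN>"  # placeholder; must be set before the FastAPI startup event runs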

backend-python/utils/torch.py

@@ -0,0 +1,26 @@
import os
import sysconfig


def set_torch():
    # make torch's bundled DLLs discoverable on Windows by appending them to PATH
    torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
    paths = os.environ.get("PATH", "")
    if os.path.exists(torch_path):
        print(f"torch found: {torch_path}")
        if torch_path in paths:
            print("torch already set")
        else:
            print("run:")
            os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
            print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
    else:
        print("torch not found")


def torch_gc():
    import torch
    if torch.cuda.is_available():
        with torch.cuda.device(0):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
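
The reload path in routes/config.py relies on torch_gc to actually return VRAM between models. A minimal sketch of that pattern, assuming a CUDA build of torch is installed:

from utils.torch import torch_gc

model = None  # drop the last reference to the old RWKV instance first
torch_gc()    # then empty_cache()/ipc_collect() so CUDA memory is really freed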