From 0e852daf436d2aa9bcec6f1592ba64496008f639 Mon Sep 17 00:00:00 2001
From: josc146
Date: Sun, 7 May 2023 17:27:54 +0800
Subject: [PATCH] backend api

---
 .gitignore                            |   1 +
 backend-python/global_var.py          |  27 ++++
 backend-python/main.py                | 119 +++---------
 backend-python/routes/completion.py   |  58 +++++++++
 backend-python/routes/config.py       |  46 +++++++
 backend-python/utils/ngrok.py         |   9 ++
 .../{rwkv_helper.py => utils/rwkv.py} |   4 +-
 backend-python/utils/torch.py         |  26 ++++
 8 files changed, 188 insertions(+), 102 deletions(-)
 create mode 100644 backend-python/global_var.py
 create mode 100644 backend-python/routes/completion.py
 create mode 100644 backend-python/routes/config.py
 create mode 100644 backend-python/utils/ngrok.py
 rename backend-python/{rwkv_helper.py => utils/rwkv.py} (90%)
 create mode 100644 backend-python/utils/torch.py

diff --git a/.gitignore b/.gitignore
index 07c9d43..9aeb554 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 build/bin
 node_modules
 frontend/dist
+__pycache__
 .idea
 .vs
 package.json.md5
diff --git a/backend-python/global_var.py b/backend-python/global_var.py
new file mode 100644
index 0000000..7a3b879
--- /dev/null
+++ b/backend-python/global_var.py
@@ -0,0 +1,27 @@
+from enum import Enum, auto
+
+Model = 'model'
+Model_Status = 'model_status'
+
+
+class ModelStatus(Enum):
+    Offline = auto()
+    Loading = auto()
+    Working = auto()
+
+
+def init():
+    global GLOBALS
+    GLOBALS = {}
+    set(Model_Status, ModelStatus.Offline)
+
+
+def set(key, value):
+    GLOBALS[key] = value
+
+
+def get(key):
+    if key in GLOBALS:
+        return GLOBALS[key]
+    else:
+        return None
diff --git a/backend-python/main.py b/backend-python/main.py
index 9ae3861..20d3837 100644
--- a/backend-python/main.py
+++ b/backend-python/main.py
@@ -1,42 +1,15 @@
-import json
-import pathlib
-import sys
-from typing import List
 import os
-import sysconfig
+import psutil
 
-from fastapi import FastAPI, Request, status, HTTPException
-from langchain.llms import RWKV
-from pydantic import BaseModel
-from sse_starlette.sse import EventSourceResponse
+from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
 
-from rwkv_helper import rwkv_generate
-
-
-def set_torch():
-    torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
-    paths = os.environ.get("PATH", "")
-    if os.path.exists(torch_path):
-        print(f"torch found: {torch_path}")
-        if torch_path in paths:
-            print("torch already set")
-        else:
-            print("run:")
-            os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
-            print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
-    else:
-        print("torch not found")
-
-
-def torch_gc():
-    import torch
-
-    if torch.cuda.is_available():
-        with torch.cuda.device(0):
-            torch.cuda.empty_cache()
-            torch.cuda.ipc_collect()
+from utils.rwkv import *
+from utils.torch import *
+from utils.ngrok import *
+from routes import completion, config
+import global_var
 
 app = FastAPI()
@@ -49,87 +22,33 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+app.include_router(completion.router)
+app.include_router(config.router)
+
 
 @app.on_event('startup')
 def init():
-    global model
+    global_var.init()
     set_torch()
-    model = RWKV(
-        model=sys.argv[2],
-        strategy=sys.argv[1],
-        tokens_path=f"{pathlib.Path(__file__).parent.resolve()}/20B_tokenizer.json"
-    )
-
     if os.environ.get("ngrok_token") is not None:
         ngrok_connect()
 
 
-def ngrok_connect():
-    from pyngrok import ngrok, conf
-    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
-    ngrok.set_auth_token(os.environ["ngrok_token"])
-    http_tunnel = ngrok.connect(8000)
-    print(http_tunnel.public_url)
-
-
-class Message(BaseModel):
-    role: str
-    content: str
-
-
-class Body(BaseModel):
-    messages: List[Message]
-    model: str
-    stream: bool
-    max_tokens: int
-
-
 @app.get("/")
 def read_root():
     return {"Hello": "World!"}
 
 
-@app.post("update-config")
-def updateConfig(body: Body):
-    pass
-
-@app.post("/v1/chat/completions")
-@app.post("/chat/completions")
-async def completions(body: Body, request: Request):
-    global model
-
-    question = body.messages[-1]
-    if question.role == 'user':
-        question = question.content
-    else:
-        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
-
-    completion_text = ""
-    for message in body.messages:
-        if message.role == 'user':
-            completion_text += "Bob: " + message.content + "\n\n"
-        elif message.role == 'assistant':
-            completion_text += "Alice: " + message.content + "\n\n"
-    completion_text += "Alice:"
-
-    async def eval_rwkv():
-        if body.stream:
-            for response, delta in rwkv_generate(model, completion_text):
-                if await request.is_disconnected():
-                    break
-                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
-            yield "[DONE]"
-        else:
-            response = None
-            for response, delta in rwkv_generate(model, completion_text):
-                pass
-            yield json.dumps({"response": response, "model": "rwkv"})
-        # torch_gc()
-
-    return EventSourceResponse(eval_rwkv())
+@app.post("/exit")
+def exit_app():
+    parent_pid = os.getpid()
+    parent = psutil.Process(parent_pid)
+    for child in parent.children(recursive=True):
+        child.kill()
+    parent.kill()
 
 
 if __name__ == "__main__":
-    uvicorn.run("main:app", reload=False, app_dir="backend-python")
+    uvicorn.run("main:app", port=8000)
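Note: main.py is now just the app shell: routers come from routes/, helpers from utils/, and the new /exit endpoint kills the whole backend process tree via psutil. The following is a minimal smoke-test sketch, not part of the patch; it assumes the server was started locally (e.g. `python backend-python/main.py <strategy> <model-path>`) on the default port 8000, and that the third-party `requests` package is available:

    import requests

    BASE_URL = "http://127.0.0.1:8000"  # assumed local server; uvicorn.run above binds port 8000

    # Liveness check against the root route.
    print(requests.get(f"{BASE_URL}/").json())  # {'Hello': 'World!'}

    # /exit kills the server's own process tree, so the connection may be
    # reset before any response arrives; a dropped connection means it worked.
    try:
        requests.post(f"{BASE_URL}/exit", timeout=5)
    except requests.exceptions.RequestException:
        pass
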
diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py
new file mode 100644
index 0000000..7dd4307
--- /dev/null
+++ b/backend-python/routes/completion.py
@@ -0,0 +1,58 @@
+import json
+from typing import List
+
+from fastapi import APIRouter, Request, status, HTTPException
+from sse_starlette.sse import EventSourceResponse
+from pydantic import BaseModel
+from utils.rwkv import *
+import global_var
+
+router = APIRouter()
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class CompletionBody(BaseModel):
+    messages: List[Message]
+    model: str
+    stream: bool
+    max_tokens: int
+
+
+@router.post("/v1/chat/completions")
+@router.post("/chat/completions")
+async def completions(body: CompletionBody, request: Request):
+    model = global_var.get(global_var.Model)
+
+    question = body.messages[-1]
+    if question.role == 'user':
+        question = question.content
+    else:
+        raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
+
+    completion_text = ""
+    for message in body.messages:
+        if message.role == 'user':
+            completion_text += "Bob: " + message.content + "\n\n"
+        elif message.role == 'assistant':
+            completion_text += "Alice: " + message.content + "\n\n"
+    completion_text += "Alice:"
+
+    async def eval_rwkv():
+        if body.stream:
+            for response, delta in rwkv_generate(model, completion_text):
+                if await request.is_disconnected():
+                    break
+                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
+            yield "[DONE]"
+        else:
+            response = None
+            for response, delta in rwkv_generate(model, completion_text):
+                pass
+            yield json.dumps({"response": response, "model": "rwkv"})
+        # torch_gc()
+
+    return EventSourceResponse(eval_rwkv())
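Note: the completion route is served through sse-starlette's EventSourceResponse, so every string the eval_rwkv generator yields reaches the client as a `data: <payload>` event line, with a literal `[DONE]` sentinel at the end of a streamed response. A sketch of a streaming client, under the same assumptions as the note above (local server on port 8000, `requests` installed); the request body mirrors CompletionBody:

    import json
    import requests

    body = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "model": "rwkv",     # accepted by CompletionBody but unused by the route
        "stream": True,
        "max_tokens": 1000,  # likewise not consumed in this revision
    }

    with requests.post("http://127.0.0.1:8000/chat/completions",
                       json=body, stream=True) as r:
        for line in r.iter_lines(decode_unicode=True):
            # sse-starlette frames each yield as "data: ..."; skip blank keep-alive lines.
            if not line or not line.startswith("data:"):
                continue
            payload = line[len("data:"):].strip()
            if payload == "[DONE]":
                break
            chunk = json.loads(payload)
            print(chunk["choices"][0]["delta"]["content"], end="", flush=True)
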
diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py
new file mode 100644
index 0000000..d65613a
--- /dev/null
+++ b/backend-python/routes/config.py
@@ -0,0 +1,46 @@
+import pathlib
+import sys
+
+from fastapi import APIRouter, HTTPException, status
+from pydantic import BaseModel
+from langchain.llms import RWKV
+from utils.rwkv import *
+from utils.torch import *
+import global_var
+
+router = APIRouter()
+
+
+class UpdateConfigBody(BaseModel):
+    model: str = None
+    strategy: str = None
+    max_response_token: int = None
+    temperature: float = None
+    top_p: float = None
+    presence_penalty: float = None
+    count_penalty: float = None
+
+
+@router.post("/update-config")
+def update_config(body: UpdateConfigBody):
+    if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
+        return "loading"
+
+    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
+    global_var.set(global_var.Model, None)
+    torch_gc()
+
+    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
+    try:
+        global_var.set(global_var.Model, RWKV(
+            model=sys.argv[2],
+            strategy=sys.argv[1],
+            tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
+        ))
+    except Exception:
+        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
+        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")
+
+    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
+
+    return "success"
diff --git a/backend-python/utils/ngrok.py b/backend-python/utils/ngrok.py
new file mode 100644
index 0000000..463fcd0
--- /dev/null
+++ b/backend-python/utils/ngrok.py
@@ -0,0 +1,9 @@
+import os
+
+
+def ngrok_connect():
+    from pyngrok import ngrok, conf
+    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
+    ngrok.set_auth_token(os.environ["ngrok_token"])
+    http_tunnel = ngrok.connect(8000)
+    print(http_tunnel.public_url)
diff --git a/backend-python/rwkv_helper.py b/backend-python/utils/rwkv.py
similarity index 90%
rename from backend-python/rwkv_helper.py
rename to backend-python/utils/rwkv.py
index 812531c..2bd0cb7 100644
--- a/backend-python/rwkv_helper.py
+++ b/backend-python/utils/rwkv.py
@@ -15,8 +15,8 @@ def rwkv_generate(model: RWKV, prompt: str):
     for i in range(model.max_tokens_per_generation):
         for n in occurrence:
             logits[n] -= (
-                model.penalty_alpha_presence
-                + occurrence[n] * model.penalty_alpha_frequency
+                    model.penalty_alpha_presence
+                    + occurrence[n] * model.penalty_alpha_frequency
             )
         token = model.pipeline.sample_logits(
             logits, temperature=model.temperature, top_p=model.top_p
diff --git a/backend-python/utils/torch.py b/backend-python/utils/torch.py
new file mode 100644
index 0000000..21473c4
--- /dev/null
+++ b/backend-python/utils/torch.py
@@ -0,0 +1,26 @@
+import os
+import sysconfig
+
+
+def set_torch():
+    torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
+    paths = os.environ.get("PATH", "")
+    if os.path.exists(torch_path):
+        print(f"torch found: {torch_path}")
+        if torch_path in paths:
+            print("torch already set")
+        else:
+            print("run:")
+            os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
+            print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
+    else:
+        print("torch not found")
+
+
+def torch_gc():
+    import torch
+
+    if torch.cuda.is_available():
+        with torch.cuda.device(0):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
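Note: /update-config is the new model lifecycle entry point: it drops any loaded model, frees CUDA memory with torch_gc(), reloads through langchain's RWKV wrapper, and tracks progress in global_var (Offline -> Loading -> Working). In this revision the model path and strategy still come from sys.argv rather than from the request body, so the body fields are placeholders. A sketch of triggering a reload, under the same assumptions as the notes above; the model path shown is hypothetical:

    import requests

    resp = requests.post("http://127.0.0.1:8000/update-config", json={
        "model": "models/rwkv-4-raven.pth",  # hypothetical path; ignored by this revision
        "strategy": "cuda fp16",             # likewise read from sys.argv instead
    })
    # "loading" if another load is already in progress, "success" once the
    # model reaches ModelStatus.Working; a failed load raises HTTP 500.
    print(resp.status_code, resp.json())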