backend api
This commit is contained in:
parent
4795514e8f
commit
0e852daf43
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,6 +1,7 @@
|
|||||||
build/bin
|
build/bin
|
||||||
node_modules
|
node_modules
|
||||||
frontend/dist
|
frontend/dist
|
||||||
|
__pycache__
|
||||||
.idea
|
.idea
|
||||||
.vs
|
.vs
|
||||||
package.json.md5
|
package.json.md5
|
||||||
|
27
backend-python/global_var.py
Normal file
27
backend-python/global_var.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from enum import Enum, auto
|
||||||
|
|
||||||
|
Model = 'model'
|
||||||
|
Model_Status = 'model_status'
|
||||||
|
|
||||||
|
|
||||||
|
class ModelStatus(Enum):
|
||||||
|
Offline = auto()
|
||||||
|
Loading = auto()
|
||||||
|
Working = auto()
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
global GLOBALS
|
||||||
|
GLOBALS = {}
|
||||||
|
set(Model_Status, ModelStatus.Offline)
|
||||||
|
|
||||||
|
|
||||||
|
def set(key, value):
|
||||||
|
GLOBALS[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
def get(key):
|
||||||
|
if key in GLOBALS:
|
||||||
|
return GLOBALS[key]
|
||||||
|
else:
|
||||||
|
return None
|
@ -1,42 +1,15 @@
|
|||||||
import json
|
|
||||||
import pathlib
|
|
||||||
import sys
|
|
||||||
from typing import List
|
|
||||||
import os
|
import os
|
||||||
import sysconfig
|
import psutil
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, status, HTTPException
|
from fastapi import FastAPI
|
||||||
from langchain.llms import RWKV
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sse_starlette.sse import EventSourceResponse
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from rwkv_helper import rwkv_generate
|
from utils.rwkv import *
|
||||||
|
from utils.torch import *
|
||||||
|
from utils.ngrok import *
|
||||||
def set_torch():
|
from routes import completion, config
|
||||||
torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
|
import global_var
|
||||||
paths = os.environ.get("PATH", "")
|
|
||||||
if os.path.exists(torch_path):
|
|
||||||
print(f"torch found: {torch_path}")
|
|
||||||
if torch_path in paths:
|
|
||||||
print("torch already set")
|
|
||||||
else:
|
|
||||||
print("run:")
|
|
||||||
os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
|
|
||||||
print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
|
|
||||||
else:
|
|
||||||
print("torch not found")
|
|
||||||
|
|
||||||
|
|
||||||
def torch_gc():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
with torch.cuda.device(0):
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@ -49,87 +22,33 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app.include_router(completion.router)
|
||||||
|
app.include_router(config.router)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event('startup')
|
@app.on_event('startup')
|
||||||
def init():
|
def init():
|
||||||
global model
|
global_var.init()
|
||||||
|
|
||||||
set_torch()
|
set_torch()
|
||||||
|
|
||||||
model = RWKV(
|
|
||||||
model=sys.argv[2],
|
|
||||||
strategy=sys.argv[1],
|
|
||||||
tokens_path=f"{pathlib.Path(__file__).parent.resolve()}/20B_tokenizer.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
if os.environ.get("ngrok_token") is not None:
|
if os.environ.get("ngrok_token") is not None:
|
||||||
ngrok_connect()
|
ngrok_connect()
|
||||||
|
|
||||||
|
|
||||||
def ngrok_connect():
|
|
||||||
from pyngrok import ngrok, conf
|
|
||||||
conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
|
|
||||||
ngrok.set_auth_token(os.environ["ngrok_token"])
|
|
||||||
http_tunnel = ngrok.connect(8000)
|
|
||||||
print(http_tunnel.public_url)
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class Body(BaseModel):
|
|
||||||
messages: List[Message]
|
|
||||||
model: str
|
|
||||||
stream: bool
|
|
||||||
max_tokens: int
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
def read_root():
|
def read_root():
|
||||||
return {"Hello": "World!"}
|
return {"Hello": "World!"}
|
||||||
|
|
||||||
@app.post("update-config")
|
|
||||||
def updateConfig(body: Body):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
@app.post("/exit")
|
||||||
@app.post("/v1/chat/completions")
|
def read_root():
|
||||||
@app.post("/chat/completions")
|
parent_pid = os.getpid()
|
||||||
async def completions(body: Body, request: Request):
|
parent = psutil.Process(parent_pid)
|
||||||
global model
|
for child in parent.children(recursive=True):
|
||||||
|
child.kill()
|
||||||
question = body.messages[-1]
|
parent.kill()
|
||||||
if question.role == 'user':
|
|
||||||
question = question.content
|
|
||||||
else:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
|
|
||||||
|
|
||||||
completion_text = ""
|
|
||||||
for message in body.messages:
|
|
||||||
if message.role == 'user':
|
|
||||||
completion_text += "Bob: " + message.content + "\n\n"
|
|
||||||
elif message.role == 'assistant':
|
|
||||||
completion_text += "Alice: " + message.content + "\n\n"
|
|
||||||
completion_text += "Alice:"
|
|
||||||
|
|
||||||
async def eval_rwkv():
|
|
||||||
if body.stream:
|
|
||||||
for response, delta in rwkv_generate(model, completion_text):
|
|
||||||
if await request.is_disconnected():
|
|
||||||
break
|
|
||||||
yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
|
|
||||||
yield "[DONE]"
|
|
||||||
else:
|
|
||||||
response = None
|
|
||||||
for response, delta in rwkv_generate(model, completion_text):
|
|
||||||
pass
|
|
||||||
yield json.dumps({"response": response, "model": "rwkv"})
|
|
||||||
# torch_gc()
|
|
||||||
|
|
||||||
return EventSourceResponse(eval_rwkv())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run("main:app", reload=False, app_dir="backend-python")
|
uvicorn.run("main:app", port=8000)
|
||||||
|
58
backend-python/routes/completion.py
Normal file
58
backend-python/routes/completion.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request, status, HTTPException
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from utils.rwkv import *
|
||||||
|
import global_var
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class Message(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionBody(BaseModel):
|
||||||
|
messages: List[Message]
|
||||||
|
model: str
|
||||||
|
stream: bool
|
||||||
|
max_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/chat/completions")
|
||||||
|
@router.post("/chat/completions")
|
||||||
|
async def completions(body: CompletionBody, request: Request):
|
||||||
|
model = global_var.get(global_var.Model)
|
||||||
|
|
||||||
|
question = body.messages[-1]
|
||||||
|
if question.role == 'user':
|
||||||
|
question = question.content
|
||||||
|
else:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
|
||||||
|
|
||||||
|
completion_text = ""
|
||||||
|
for message in body.messages:
|
||||||
|
if message.role == 'user':
|
||||||
|
completion_text += "Bob: " + message.content + "\n\n"
|
||||||
|
elif message.role == 'assistant':
|
||||||
|
completion_text += "Alice: " + message.content + "\n\n"
|
||||||
|
completion_text += "Alice:"
|
||||||
|
|
||||||
|
async def eval_rwkv():
|
||||||
|
if body.stream:
|
||||||
|
for response, delta in rwkv_generate(model, completion_text):
|
||||||
|
if await request.is_disconnected():
|
||||||
|
break
|
||||||
|
yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
|
||||||
|
yield "[DONE]"
|
||||||
|
else:
|
||||||
|
response = None
|
||||||
|
for response, delta in rwkv_generate(model, completion_text):
|
||||||
|
pass
|
||||||
|
yield json.dumps({"response": response, "model": "rwkv"})
|
||||||
|
# torch_gc()
|
||||||
|
|
||||||
|
return EventSourceResponse(eval_rwkv())
|
46
backend-python/routes/config.py
Normal file
46
backend-python/routes/config.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from langchain.llms import RWKV
|
||||||
|
from utils.rwkv import *
|
||||||
|
from utils.torch import *
|
||||||
|
import global_var
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateConfigBody(BaseModel):
|
||||||
|
model: str = None
|
||||||
|
strategy: str = None
|
||||||
|
max_response_token: int = None
|
||||||
|
temperature: float = None
|
||||||
|
top_p: float = None
|
||||||
|
presence_penalty: float = None
|
||||||
|
count_penalty: float = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/update-config")
|
||||||
|
def update_config(body: UpdateConfigBody):
|
||||||
|
if (global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading):
|
||||||
|
return "loading"
|
||||||
|
|
||||||
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
|
||||||
|
global_var.set(global_var.Model, None)
|
||||||
|
torch_gc()
|
||||||
|
|
||||||
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||||
|
try:
|
||||||
|
global_var.set(global_var.Model, RWKV(
|
||||||
|
model=sys.argv[2],
|
||||||
|
strategy=sys.argv[1],
|
||||||
|
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
|
||||||
|
))
|
||||||
|
except Exception:
|
||||||
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
|
||||||
|
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")
|
||||||
|
|
||||||
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
|
||||||
|
|
||||||
|
return "success"
|
9
backend-python/utils/ngrok.py
Normal file
9
backend-python/utils/ngrok.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def ngrok_connect():
|
||||||
|
from pyngrok import ngrok, conf
|
||||||
|
conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
|
||||||
|
ngrok.set_auth_token(os.environ["ngrok_token"])
|
||||||
|
http_tunnel = ngrok.connect(8000)
|
||||||
|
print(http_tunnel.public_url)
|
@ -15,8 +15,8 @@ def rwkv_generate(model: RWKV, prompt: str):
|
|||||||
for i in range(model.max_tokens_per_generation):
|
for i in range(model.max_tokens_per_generation):
|
||||||
for n in occurrence:
|
for n in occurrence:
|
||||||
logits[n] -= (
|
logits[n] -= (
|
||||||
model.penalty_alpha_presence
|
model.penalty_alpha_presence
|
||||||
+ occurrence[n] * model.penalty_alpha_frequency
|
+ occurrence[n] * model.penalty_alpha_frequency
|
||||||
)
|
)
|
||||||
token = model.pipeline.sample_logits(
|
token = model.pipeline.sample_logits(
|
||||||
logits, temperature=model.temperature, top_p=model.top_p
|
logits, temperature=model.temperature, top_p=model.top_p
|
26
backend-python/utils/torch.py
Normal file
26
backend-python/utils/torch.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import os
|
||||||
|
import sysconfig
|
||||||
|
|
||||||
|
|
||||||
|
def set_torch():
|
||||||
|
torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
|
||||||
|
paths = os.environ.get("PATH", "")
|
||||||
|
if os.path.exists(torch_path):
|
||||||
|
print(f"torch found: {torch_path}")
|
||||||
|
if torch_path in paths:
|
||||||
|
print("torch already set")
|
||||||
|
else:
|
||||||
|
print("run:")
|
||||||
|
os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
|
||||||
|
print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
|
||||||
|
else:
|
||||||
|
print("torch not found")
|
||||||
|
|
||||||
|
|
||||||
|
def torch_gc():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
with torch.cuda.device(0):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.ipc_collect()
|
Loading…
Reference in New Issue
Block a user