136 lines
3.4 KiB
Python
136 lines
3.4 KiB
Python
import json
|
|
import pathlib
|
|
import sys
|
|
from typing import List
|
|
import os
|
|
import sysconfig
|
|
|
|
from fastapi import FastAPI, Request, status, HTTPException
|
|
from langchain.llms import RWKV
|
|
from pydantic import BaseModel
|
|
from sse_starlette.sse import EventSourceResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import uvicorn
|
|
|
|
from rwkv_helper import rwkv_generate
|
|
|
|
|
|
def set_torch():
|
|
torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
|
|
paths = os.environ.get("PATH", "")
|
|
if os.path.exists(torch_path):
|
|
print(f"torch found: {torch_path}")
|
|
if torch_path in paths:
|
|
print("torch already set")
|
|
else:
|
|
print("run:")
|
|
os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
|
|
print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
|
|
else:
|
|
print("torch not found")
|
|
|
|
|
|
def torch_gc():
|
|
import torch
|
|
|
|
if torch.cuda.is_available():
|
|
with torch.cuda.device(0):
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.ipc_collect()
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
|
|
@app.on_event('startup')
|
|
def init():
|
|
global model
|
|
|
|
set_torch()
|
|
|
|
model = RWKV(
|
|
model=sys.argv[2],
|
|
strategy=sys.argv[1],
|
|
tokens_path=f"{pathlib.Path(__file__).parent.resolve()}/20B_tokenizer.json"
|
|
)
|
|
|
|
if os.environ.get("ngrok_token") is not None:
|
|
ngrok_connect()
|
|
|
|
|
|
def ngrok_connect():
|
|
from pyngrok import ngrok, conf
|
|
conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
|
|
ngrok.set_auth_token(os.environ["ngrok_token"])
|
|
http_tunnel = ngrok.connect(8000)
|
|
print(http_tunnel.public_url)
|
|
|
|
|
|
class Message(BaseModel):
|
|
role: str
|
|
content: str
|
|
|
|
|
|
class Body(BaseModel):
|
|
messages: List[Message]
|
|
model: str
|
|
stream: bool
|
|
max_tokens: int
|
|
|
|
|
|
@app.get("/")
|
|
def read_root():
|
|
return {"Hello": "World!"}
|
|
|
|
@app.post("update-config")
|
|
def updateConfig(body: Body):
|
|
pass
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
|
@app.post("/chat/completions")
|
|
async def completions(body: Body, request: Request):
|
|
global model
|
|
|
|
question = body.messages[-1]
|
|
if question.role == 'user':
|
|
question = question.content
|
|
else:
|
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
|
|
|
|
completion_text = ""
|
|
for message in body.messages:
|
|
if message.role == 'user':
|
|
completion_text += "Bob: " + message.content + "\n\n"
|
|
elif message.role == 'assistant':
|
|
completion_text += "Alice: " + message.content + "\n\n"
|
|
completion_text += "Alice:"
|
|
|
|
async def eval_rwkv():
|
|
if body.stream:
|
|
for response, delta in rwkv_generate(model, completion_text):
|
|
if await request.is_disconnected():
|
|
break
|
|
yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
|
|
yield "[DONE]"
|
|
else:
|
|
response = None
|
|
for response, delta in rwkv_generate(model, completion_text):
|
|
pass
|
|
yield json.dumps({"response": response, "model": "rwkv"})
|
|
# torch_gc()
|
|
|
|
return EventSourceResponse(eval_rwkv())
|
|
|
|
|
|
if __name__ == "__main__":
|
|
uvicorn.run("main:app", reload=False, app_dir="backend-python")
|