josc146
2023-05-06 20:17:39 +08:00
parent f6be32825f
commit ac3e34e1d8
18 changed files with 101037 additions and 295 deletions

File diff suppressed because it is too large

135
backend-python/main.py Normal file

@@ -0,0 +1,135 @@
import json
import os
import pathlib
import sys
import sysconfig
from typing import List

import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from langchain.llms import RWKV
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from rwkv_helper import rwkv_generate


def set_torch():
    # On Windows, add torch's bundled DLL directory to PATH so its libraries can be found
    torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
    paths = os.environ.get("PATH", "")
    if os.path.exists(torch_path):
        print(f"torch found: {torch_path}")
        if torch_path in paths:
            print("torch already set")
        else:
            print("run:")
            os.environ["PATH"] = paths + os.pathsep + torch_path + os.pathsep
            print(f"set Path={paths + os.pathsep + torch_path + os.pathsep}")
    else:
        print("torch not found")


def torch_gc():
    # Release cached GPU memory between generations
    import torch

    if torch.cuda.is_available():
        with torch.cuda.device(0):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def init():
    global model

    set_torch()
    # argv[1] is the RWKV strategy string (e.g. "cuda fp16"), argv[2] is the model path
    model = RWKV(
        model=sys.argv[2],
        strategy=sys.argv[1],
        tokens_path=f"{pathlib.Path(__file__).parent.resolve()}/20B_tokenizer.json",
    )

    if os.environ.get("ngrok_token") is not None:
        ngrok_connect()


def ngrok_connect():
    # Expose the local server through an ngrok tunnel when ngrok_token is set
    from pyngrok import ngrok, conf

    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
    ngrok.set_auth_token(os.environ["ngrok_token"])
    http_tunnel = ngrok.connect(8000)
    print(http_tunnel.public_url)


class Message(BaseModel):
    role: str
    content: str


class Body(BaseModel):
    messages: List[Message]
    model: str
    stream: bool
    max_tokens: int


@app.get("/")
def read_root():
    return {"Hello": "World!"}


@app.post("/update-config")
def updateConfig(body: Body):
    pass


@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def completions(body: Body, request: Request):
    global model

    # The last message must come from the user; it is the question being answered
    question = body.messages[-1]
    if question.role == "user":
        question = question.content
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")

    # Rebuild the conversation in the Bob/Alice prompt format used by RWKV chat models
    completion_text = ""
    for message in body.messages:
        if message.role == "user":
            completion_text += "Bob: " + message.content + "\n\n"
        elif message.role == "assistant":
            completion_text += "Alice: " + message.content + "\n\n"
    completion_text += "Alice:"

    async def eval_rwkv():
        if body.stream:
            # Stream partial results as SSE events, stopping if the client disconnects
            for response, delta in rwkv_generate(model, completion_text):
                if await request.is_disconnected():
                    break
                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
            yield "[DONE]"
        else:
            # Run the generator to completion and return only the final response
            response = None
            for response, delta in rwkv_generate(model, completion_text):
                pass
            yield json.dumps({"response": response, "model": "rwkv"})
        # torch_gc()

    return EventSourceResponse(eval_rwkv())


if __name__ == "__main__":
    uvicorn.run("main:app", reload=False, app_dir="backend-python")
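
main.py exposes a small OpenAI-style chat endpoint that streams results as server-sent events on the default uvicorn port 8000. As a rough usage sketch (not part of this commit): start the server with a strategy string and a model path, then consume the stream with any SSE-capable client. The launch command, model filename, and the requests dependency below are illustrative assumptions, not something this commit ships.

# Hypothetical client sketch; launch the server first, e.g.:
#   python ./backend-python/main.py "cuda fp16" ./models/your-rwkv-model.pth
import json

import requests  # assumed to be available on the client side

payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "model": "rwkv",
    "stream": True,
    "max_tokens": 1000,
}

with requests.post(
    "http://127.0.0.1:8000/chat/completions", json=payload, stream=True
) as r:
    for line in r.iter_lines(decode_unicode=True):
        # sse_starlette frames each yielded string as "data: <payload>"
        if not line or not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["delta"]["content"], end="", flush=True)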

40
backend-python/rwkv_helper.py Normal file

@@ -0,0 +1,40 @@
from typing import Dict

from langchain.llms import RWKV

END_OF_TEXT = 0


def rwkv_generate(model: RWKV, prompt: str):
    # Reset the recurrent state, feed the whole prompt, then sample token by token
    model.model_state = None
    model.model_tokens = []
    logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
    begin = len(model.model_tokens)
    out_last = begin

    occurrence: Dict = {}
    response = ""
    for i in range(model.max_tokens_per_generation):
        # Apply presence and frequency penalties to tokens that already appeared
        for n in occurrence:
            logits[n] -= (
                model.penalty_alpha_presence
                + occurrence[n] * model.penalty_alpha_frequency
            )
        token = model.pipeline.sample_logits(
            logits, temperature=model.temperature, top_p=model.top_p
        )
        if token == END_OF_TEXT:
            break
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1

        logits = model.run_rnn([token])
        delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
        if "\ufffd" not in delta:  # avoid yielding incomplete UTF-8 sequences
            response += delta
            yield response, delta
            out_last = begin + i + 1

        if i >= model.max_tokens_per_generation - 100:
            break
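
The penalty adjustment inside the sampling loop follows a presence/frequency scheme: a token that has already been generated count times has its logit reduced by penalty_alpha_presence + count * penalty_alpha_frequency before the next sample is drawn. A minimal, self-contained sketch of just that step, with dummy logits and penalty values that are not tied to the RWKV model:

# Illustrative sketch of the penalty step above; all numbers are made up.
from typing import Dict, List


def apply_penalties(
    logits: List[float],
    occurrence: Dict[int, int],
    alpha_presence: float,
    alpha_frequency: float,
) -> List[float]:
    # Each previously generated token id n is down-weighted by
    # alpha_presence + count(n) * alpha_frequency, mirroring the loop above.
    adjusted = list(logits)
    for n, count in occurrence.items():
        adjusted[n] -= alpha_presence + count * alpha_frequency
    return adjusted


# Token id 2 has appeared twice, so its logit drops by 0.4 + 2 * 0.4 = 1.2.
print(apply_penalties([1.0, 2.0, 3.0], {2: 2}, 0.4, 0.4))  # [1.0, 2.0, 1.8]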