josc146 2023-06-03 17:12:59 +08:00
parent f2ec1067bf
commit 38b775c937
4 changed files with 102 additions and 6 deletions

.gitignore (vendored)

@@ -17,3 +17,5 @@ __pycache__
 *.exe
 *.old
 .DS_Store
+*.log.*
+*.log
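The two new patterns cover both the live api.log written by the new logging module below and the rotated backups it produces (api.log.1, api.log.2, ...), keeping all of them out of version control.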

main.py

@@ -4,17 +4,18 @@ import sys
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))

 import psutil
-from fastapi import FastAPI
+from fastapi import Depends, FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
 from utils.rwkv import *
 from utils.torch import *
 from utils.ngrok import *
+from utils.log import log_middleware
 from routes import completion, config, state_cache
 import global_var

-app = FastAPI()
+app = FastAPI(dependencies=[Depends(log_middleware)])

 app.add_middleware(
     CORSMiddleware,

@@ -42,7 +43,7 @@ def init():
 @app.get("/")
 def read_root():
-    return {"Hello": "World!", "pid": os.getpid()}
+    return {"Hello": "World!"}


 @app.post("/exit")

@@ -60,7 +61,7 @@ def debug():
         strategy="cuda fp16",
         tokens_path="20B_tokenizer.json",
     )
-    d = model.tokenizer.decode([])
+    d = model.pipeline.decode([])
     print(d)
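
For context on the wiring above: passing dependencies= to the FastAPI constructor registers a dependency that runs before every path operation, which is how log_middleware sees each request without an explicit middleware class. A minimal standalone sketch (print in place of the repo's logger; demo_log_middleware is an illustrative name, not the repo's code):

    from fastapi import Depends, FastAPI, Request


    async def demo_log_middleware(request: Request):
        # Stand-in for utils.log.log_middleware; runs before every handler.
        print(f"Client: {request.client}\nUrl: {request.url}")


    app = FastAPI(dependencies=[Depends(demo_log_middleware)])


    @app.get("/")
    def read_root():
        return {"Hello": "World!"}

Served with uvicorn, every incoming request prints its client and URL before the handler executes.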

routes/completion.py

@@ -7,6 +7,7 @@ from fastapi import APIRouter, Request, status, HTTPException
 from sse_starlette.sse import EventSourceResponse
 from pydantic import BaseModel
 from utils.rwkv import *
+from utils.log import quick_log
 import global_var

 router = APIRouter()

@@ -26,6 +27,8 @@ class ChatCompletionBody(ModelConfigBody):
 completion_lock = Lock()

+requests_num = 0
+

 @router.post("/v1/chat/completions")
 @router.post("/chat/completions")
@@ -106,8 +109,15 @@ The following is a coherent verbose detailed conversation between a girl named {
     completion_text += f"{bot}{interface}"

     async def eval_rwkv():
+        global requests_num
+        requests_num = requests_num + 1
+        quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
         while completion_lock.locked():
             if await request.is_disconnected():
+                requests_num = requests_num - 1
+                quick_log(
+                    request, None, "Stop Waiting. RequestsNum: " + str(requests_num)
+                )
                 return
             await asyncio.sleep(0.1)
         else:
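
Two things worth noting in the queueing hunk above: requests_num is a module-level counter, so the nested generator must declare global requests_num before mutating it, and the wait is a poll on a threading Lock inside a while/else block, whose else clause runs once the loop condition turns false. A minimal, self-contained sketch of that control flow (illustrative names, not the repo's code):

    import asyncio
    from threading import Lock

    completion_lock = Lock()
    requests_num = 0


    async def wait_for_turn():
        global requests_num  # the counter lives at module level
        requests_num += 1
        while completion_lock.locked():
            await asyncio.sleep(0.1)  # yield to the event loop while queued
        else:  # reached when the loop exits normally, i.e. the lock is free
            completion_lock.acquire()
            try:
                return f"running, queue depth {requests_num}"
            finally:
                requests_num -= 1
                completion_lock.release()


    print(asyncio.run(wait_for_turn()))  # running, queue depth 1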
@@ -135,9 +145,21 @@ The following is a coherent verbose detailed conversation between a girl named {
                         }
                     )
                 # torch_gc()
+                requests_num = requests_num - 1
                 completion_lock.release()
                 if await request.is_disconnected():
-                    print(f"{request.client} Stop Waiting")
+                    quick_log(
+                        request,
+                        body,
+                        response + "\nStop Waiting. RequestsNum: " + str(requests_num),
+                    )
                     return
+                quick_log(
+                    request,
+                    body,
+                    response + "\nFinished. RequestsNum: " + str(requests_num),
+                )
                 yield json.dumps(
                     {
                         "response": response,
@@ -161,6 +183,12 @@ The following is a coherent verbose detailed conversation between a girl named {
                     if await request.is_disconnected():
                         break
                 # torch_gc()
+                requests_num = requests_num - 1
+                quick_log(
+                    request,
+                    body,
+                    response + "\nFinished. RequestsNum: " + str(requests_num),
+                )
                 completion_lock.release()
                 if await request.is_disconnected():
                     return
@@ -182,7 +210,11 @@ The following is a coherent verbose detailed conversation between a girl named {
     if body.stream:
         return EventSourceResponse(eval_rwkv())
     else:
-        return await eval_rwkv().__anext__()
+        try:
+            return await eval_rwkv().__anext__()
+        except StopAsyncIteration:
+            print(f"{request.client} Stop Waiting")
+            return None


 class CompletionBody(ModelConfigBody):
@@ -203,8 +235,15 @@ async def completions(body: CompletionBody, request: Request):
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")

     async def eval_rwkv():
+        global requests_num
+        requests_num = requests_num + 1
+        quick_log(request, None, "Start Waiting. RequestsNum: " + str(requests_num))
         while completion_lock.locked():
             if await request.is_disconnected():
+                requests_num = requests_num - 1
+                quick_log(
+                    request, None, "Stop Waiting. RequestsNum: " + str(requests_num)
+                )
                 return
             await asyncio.sleep(0.1)
         else:
@@ -229,9 +268,21 @@ async def completions(body: CompletionBody, request: Request):
                         }
                     )
                 # torch_gc()
+                requests_num = requests_num - 1
                 completion_lock.release()
                 if await request.is_disconnected():
-                    print(f"{request.client} Stop Waiting")
+                    quick_log(
+                        request,
+                        body,
+                        response + "\nStop Waiting. RequestsNum: " + str(requests_num),
+                    )
                     return
+                quick_log(
+                    request,
+                    body,
+                    response + "\nFinished. RequestsNum: " + str(requests_num),
+                )
                 yield json.dumps(
                     {
                         "response": response,
@@ -252,6 +303,12 @@ async def completions(body: CompletionBody, request: Request):
                     if await request.is_disconnected():
                         break
                 # torch_gc()
+                requests_num = requests_num - 1
+                quick_log(
+                    request,
+                    body,
+                    response + "\nFinished. RequestsNum: " + str(requests_num),
+                )
                 completion_lock.release()
                 if await request.is_disconnected():
                     return
@@ -270,4 +327,8 @@ async def completions(body: CompletionBody, request: Request):
     if body.stream:
         return EventSourceResponse(eval_rwkv())
     else:
-        return await eval_rwkv().__anext__()
+        try:
+            return await eval_rwkv().__anext__()
+        except StopAsyncIteration:
+            print(f"{request.client} Stop Waiting")
+            return None
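
Both routes reuse the same trick: a single async generator drives the SSE stream and the non-stream JSON path, where only the first yielded item is awaited via __anext__(). If the client disconnects while queued, the generator returns before yielding anything, so __anext__() raises StopAsyncIteration; the new try/except converts that into a quiet None. A standalone sketch (illustrative names, not the repo's code):

    import asyncio


    async def eval_sketch(disconnected: bool):
        if disconnected:
            return  # bail out with no yield; consumer sees StopAsyncIteration
        yield "full response"


    async def non_stream_path(disconnected: bool):
        try:
            return await eval_sketch(disconnected).__anext__()
        except StopAsyncIteration:
            return None


    assert asyncio.run(non_stream_path(False)) == "full response"
    assert asyncio.run(non_stream_path(True)) is None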

utils/log.py (new file)

@@ -0,0 +1,33 @@
+import json
+import logging
+import logging.handlers  # RotatingFileHandler lives in the handlers submodule
+from typing import Any
+
+from fastapi import Request
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter("%(asctime)s - %(levelname)s\n%(message)s")
+fh = logging.handlers.RotatingFileHandler(
+    "api.log", mode="a", maxBytes=3 * 1024 * 1024, backupCount=3
+)
+fh.setFormatter(formatter)
+logger.addHandler(fh)
+
+
+def quick_log(request: Request, body: Any, response: str):
+    logger.info(
+        f"Client: {request.client}\nUrl: {request.url}\n"
+        + (
+            f"Body: {json.dumps(body.__dict__, default=vars, ensure_ascii=False)}\n"
+            if body
+            else ""
+        )
+        + (f"Response:\n{response}\n" if response else "")
+    )
+
+
+async def log_middleware(request: Request):
+    logger.info(
+        f"Client: {request.client}\nUrl: {request.url}\nBody: {await request.body()}\n"
+    )
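
A hedged usage sketch, assuming the module is importable as utils.log (consistent with the imports in main.py above): the app-wide dependency logs each request on arrival, quick_log is called explicitly with a parsed body and response text, and RotatingFileHandler caps api.log at 3 MiB, shifting older output into api.log.1 through api.log.3. The /ping route is illustrative; TestClient only exercises the app and needs FastAPI's test dependencies installed.

    from fastapi import Depends, FastAPI, Request
    from fastapi.testclient import TestClient

    from utils.log import log_middleware, quick_log  # module path assumed from main.py

    app = FastAPI(dependencies=[Depends(log_middleware)])


    @app.get("/ping")
    def ping(request: Request):
        quick_log(request, None, "pong")  # body=None skips the Body line
        return {"pong": True}


    client = TestClient(app)
    client.get("/ping")  # both records land in api.log via the rotating handler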