preliminary usable features

This commit is contained in:
josc146
2023-05-17 11:39:00 +08:00
parent 53502a8c3d
commit c947052574
21 changed files with 1187 additions and 1080 deletions

View File

@@ -1,4 +1,6 @@
import asyncio
import json
from threading import Lock
from typing import List
from fastapi import APIRouter, Request, status, HTTPException
@@ -15,46 +17,99 @@ class Message(BaseModel):
content: str
class CompletionBody(BaseModel):
class CompletionBody(ModelConfigBody):
messages: List[Message]
model: str
stream: bool
max_tokens: int
completion_lock = Lock()
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def completions(body: CompletionBody, request: Request):
model = global_var.get(global_var.Model)
if (model is None):
model: RWKV = global_var.get(global_var.Model)
if model is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
question = body.messages[-1]
if question.role == 'user':
if question.role == "user":
question = question.content
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
completion_text = ""
for message in body.messages:
if message.role == 'user':
if message.role == "user":
completion_text += "Bob: " + message.content + "\n\n"
elif message.role == 'assistant':
elif message.role == "assistant":
completion_text += "Alice: " + message.content + "\n\n"
completion_text += "Alice:"
async def eval_rwkv():
if body.stream:
for response, delta in rwkv_generate(model, completion_text, stop="Bob:"):
if await request.is_disconnected():
break
yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
yield "[DONE]"
while completion_lock.locked():
await asyncio.sleep(0.1)
else:
response = None
for response, delta in rwkv_generate(model, completion_text, stop="Bob:"):
pass
yield json.dumps({"response": response, "model": "rwkv"})
# torch_gc()
with completion_lock:
set_rwkv_config(model, global_var.get(global_var.Model_Config))
set_rwkv_config(model, body)
if body.stream:
for response, delta in rwkv_generate(
model, completion_text, stop="Bob:"
):
if await request.is_disconnected():
break
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"delta": {"content": delta},
"index": 0,
"finish_reason": None,
}
],
}
)
yield json.dumps(
{
"response": response,
"model": "rwkv",
"choices": [
{
"delta": {},
"index": 0,
"finish_reason": "stop",
}
],
}
)
yield "[DONE]"
else:
response = None
for response, delta in rwkv_generate(
model, completion_text, stop="Bob:"
):
pass
yield {
"response": response,
"model": "rwkv",
"choices": [
{
"message": {
"role": "assistant",
"content": response,
},
"index": 0,
"finish_reason": "stop",
}
],
}
# torch_gc()
return EventSourceResponse(eval_rwkv())
if body.stream:
return EventSourceResponse(eval_rwkv())
else:
return await eval_rwkv().__anext__()

View File

@@ -1,5 +1,4 @@
import pathlib
import sys
from fastapi import APIRouter, HTTPException, Response, status
from pydantic import BaseModel
@@ -11,19 +10,14 @@ import global_var
router = APIRouter()
class UpdateConfigBody(BaseModel):
model: str = None
strategy: str = None
max_response_token: int = None
temperature: float = None
top_p: float = None
presence_penalty: float = None
count_penalty: float = None
class SwitchModelBody(BaseModel):
model: str
strategy: str
@router.post("/update-config")
def update_config(body: UpdateConfigBody, response: Response):
if (global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading):
@router.post("/switch-model")
def switch_model(body: SwitchModelBody, response: Response):
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
response.status_code = status.HTTP_304_NOT_MODIFIED
return
@@ -33,15 +27,34 @@ def update_config(body: UpdateConfigBody, response: Response):
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
try:
global_var.set(global_var.Model, RWKV(
model=sys.argv[2],
strategy=sys.argv[1],
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
))
global_var.set(
global_var.Model,
RWKV(
model=body.model,
strategy=body.strategy,
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
),
)
except Exception:
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")
if global_var.get(global_var.Model_Config) is None:
global_var.set(
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
)
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
return "success"
@router.post("/update-config")
def update_config(body: ModelConfigBody):
"""
Will not update the model config immediately, but set it when completion called to avoid modifications during generation
"""
print(body)
global_var.set(global_var.Model_Config, body)
return "success"