preliminary usable features
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from enum import Enum, auto
|
||||
|
||||
Model = 'model'
|
||||
Model_Status = 'model_status'
|
||||
Model = "model"
|
||||
Model_Status = "model_status"
|
||||
Model_Config = "model_config"
|
||||
|
||||
|
||||
class ModelStatus(Enum):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import psutil
|
||||
import sys
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -26,7 +27,7 @@ app.include_router(completion.router)
|
||||
app.include_router(config.router)
|
||||
|
||||
|
||||
@app.on_event('startup')
|
||||
@app.on_event("startup")
|
||||
def init():
|
||||
global_var.init()
|
||||
|
||||
@@ -38,7 +39,7 @@ def init():
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {"Hello": "World!"}
|
||||
return {"Hello": "World!", "pid": os.getpid()}
|
||||
|
||||
|
||||
@app.post("/exit")
|
||||
@@ -51,4 +52,4 @@ def exit():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("main:app", port=8000)
|
||||
uvicorn.run("main:app", port=8000 if len(sys.argv) == 1 else int(sys.argv[1]))
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from threading import Lock
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Request, status, HTTPException
|
||||
@@ -15,46 +17,99 @@ class Message(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class CompletionBody(BaseModel):
|
||||
class CompletionBody(ModelConfigBody):
|
||||
messages: List[Message]
|
||||
model: str
|
||||
stream: bool
|
||||
max_tokens: int
|
||||
|
||||
|
||||
completion_lock = Lock()
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@router.post("/chat/completions")
|
||||
async def completions(body: CompletionBody, request: Request):
|
||||
model = global_var.get(global_var.Model)
|
||||
if (model is None):
|
||||
model: RWKV = global_var.get(global_var.Model)
|
||||
if model is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
|
||||
|
||||
question = body.messages[-1]
|
||||
if question.role == 'user':
|
||||
if question.role == "user":
|
||||
question = question.content
|
||||
else:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
|
||||
|
||||
completion_text = ""
|
||||
for message in body.messages:
|
||||
if message.role == 'user':
|
||||
if message.role == "user":
|
||||
completion_text += "Bob: " + message.content + "\n\n"
|
||||
elif message.role == 'assistant':
|
||||
elif message.role == "assistant":
|
||||
completion_text += "Alice: " + message.content + "\n\n"
|
||||
completion_text += "Alice:"
|
||||
|
||||
async def eval_rwkv():
|
||||
if body.stream:
|
||||
for response, delta in rwkv_generate(model, completion_text, stop="Bob:"):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
|
||||
yield "[DONE]"
|
||||
while completion_lock.locked():
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
response = None
|
||||
for response, delta in rwkv_generate(model, completion_text, stop="Bob:"):
|
||||
pass
|
||||
yield json.dumps({"response": response, "model": "rwkv"})
|
||||
# torch_gc()
|
||||
with completion_lock:
|
||||
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||
set_rwkv_config(model, body)
|
||||
if body.stream:
|
||||
for response, delta in rwkv_generate(
|
||||
model, completion_text, stop="Bob:"
|
||||
):
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield json.dumps(
|
||||
{
|
||||
"response": response,
|
||||
"model": "rwkv",
|
||||
"choices": [
|
||||
{
|
||||
"delta": {"content": delta},
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
yield json.dumps(
|
||||
{
|
||||
"response": response,
|
||||
"model": "rwkv",
|
||||
"choices": [
|
||||
{
|
||||
"delta": {},
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
yield "[DONE]"
|
||||
else:
|
||||
response = None
|
||||
for response, delta in rwkv_generate(
|
||||
model, completion_text, stop="Bob:"
|
||||
):
|
||||
pass
|
||||
yield {
|
||||
"response": response,
|
||||
"model": "rwkv",
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
# torch_gc()
|
||||
|
||||
return EventSourceResponse(eval_rwkv())
|
||||
if body.stream:
|
||||
return EventSourceResponse(eval_rwkv())
|
||||
else:
|
||||
return await eval_rwkv().__anext__()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Response, status
|
||||
from pydantic import BaseModel
|
||||
@@ -11,19 +10,14 @@ import global_var
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class UpdateConfigBody(BaseModel):
|
||||
model: str = None
|
||||
strategy: str = None
|
||||
max_response_token: int = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
presence_penalty: float = None
|
||||
count_penalty: float = None
|
||||
class SwitchModelBody(BaseModel):
|
||||
model: str
|
||||
strategy: str
|
||||
|
||||
|
||||
@router.post("/update-config")
|
||||
def update_config(body: UpdateConfigBody, response: Response):
|
||||
if (global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading):
|
||||
@router.post("/switch-model")
|
||||
def switch_model(body: SwitchModelBody, response: Response):
|
||||
if global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading:
|
||||
response.status_code = status.HTTP_304_NOT_MODIFIED
|
||||
return
|
||||
|
||||
@@ -33,15 +27,34 @@ def update_config(body: UpdateConfigBody, response: Response):
|
||||
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||
try:
|
||||
global_var.set(global_var.Model, RWKV(
|
||||
model=sys.argv[2],
|
||||
strategy=sys.argv[1],
|
||||
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
|
||||
))
|
||||
global_var.set(
|
||||
global_var.Model,
|
||||
RWKV(
|
||||
model=body.model,
|
||||
strategy=body.strategy,
|
||||
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")
|
||||
|
||||
if global_var.get(global_var.Model_Config) is None:
|
||||
global_var.set(
|
||||
global_var.Model_Config, get_rwkv_config(global_var.get(global_var.Model))
|
||||
)
|
||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
|
||||
|
||||
return "success"
|
||||
|
||||
|
||||
@router.post("/update-config")
|
||||
def update_config(body: ModelConfigBody):
|
||||
"""
|
||||
Will not update the model config immediately, but set it when completion called to avoid modifications during generation
|
||||
"""
|
||||
|
||||
print(body)
|
||||
global_var.set(global_var.Model_Config, body)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
|
||||
def ngrok_connect():
|
||||
from pyngrok import ngrok, conf
|
||||
|
||||
conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
|
||||
ngrok.set_auth_token(os.environ["ngrok_token"])
|
||||
http_tunnel = ngrok.connect(8000)
|
||||
|
||||
@@ -1,5 +1,37 @@
|
||||
from typing import Dict
|
||||
from langchain.llms import RWKV
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ModelConfigBody(BaseModel):
|
||||
max_tokens: int = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
presence_penalty: float = None
|
||||
frequency_penalty: float = None
|
||||
|
||||
|
||||
def set_rwkv_config(model: RWKV, body: ModelConfigBody):
|
||||
if body.max_tokens:
|
||||
model.max_tokens_per_generation = body.max_tokens
|
||||
if body.temperature:
|
||||
model.temperature = body.temperature
|
||||
if body.top_p:
|
||||
model.top_p = body.top_p
|
||||
if body.presence_penalty:
|
||||
model.penalty_alpha_presence = body.presence_penalty
|
||||
if body.frequency_penalty:
|
||||
model.penalty_alpha_frequency = body.frequency_penalty
|
||||
|
||||
|
||||
def get_rwkv_config(model: RWKV) -> ModelConfigBody:
|
||||
return ModelConfigBody(
|
||||
max_tokens=model.max_tokens_per_generation,
|
||||
temperature=model.temperature,
|
||||
top_p=model.top_p,
|
||||
presence_penalty=model.penalty_alpha_presence,
|
||||
frequency_penalty=model.penalty_alpha_frequency,
|
||||
)
|
||||
|
||||
|
||||
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
|
||||
|
||||
@@ -11,8 +11,8 @@ def set_torch():
|
||||
print("torch already set")
|
||||
else:
|
||||
print("run:")
|
||||
os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
|
||||
print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
|
||||
os.environ["PATH"] = paths + os.pathsep + torch_path + os.pathsep
|
||||
print(f"set Path={paths + os.pathsep + torch_path + os.pathsep}")
|
||||
else:
|
||||
print("torch not found")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user