backend api
This commit is contained in:
		
							parent
							
								
									4795514e8f
								
							
						
					
					
						commit
						0e852daf43
					
				
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -1,6 +1,7 @@
 | 
			
		||||
build/bin
 | 
			
		||||
node_modules
 | 
			
		||||
frontend/dist
 | 
			
		||||
__pycache__
 | 
			
		||||
.idea
 | 
			
		||||
.vs
 | 
			
		||||
package.json.md5
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										27
									
								
								backend-python/global_var.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								backend-python/global_var.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,27 @@
 | 
			
		||||
from enum import Enum, auto
 | 
			
		||||
 | 
			
		||||
Model = 'model'
 | 
			
		||||
Model_Status = 'model_status'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ModelStatus(Enum):
 | 
			
		||||
    Offline = auto()
 | 
			
		||||
    Loading = auto()
 | 
			
		||||
    Working = auto()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init():
 | 
			
		||||
    global GLOBALS
 | 
			
		||||
    GLOBALS = {}
 | 
			
		||||
    set(Model_Status, ModelStatus.Offline)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set(key, value):
 | 
			
		||||
    GLOBALS[key] = value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get(key):
 | 
			
		||||
    if key in GLOBALS:
 | 
			
		||||
        return GLOBALS[key]
 | 
			
		||||
    else:
 | 
			
		||||
        return None
 | 
			
		||||
@ -1,42 +1,15 @@
 | 
			
		||||
import json
 | 
			
		||||
import pathlib
 | 
			
		||||
import sys
 | 
			
		||||
from typing import List
 | 
			
		||||
import os
 | 
			
		||||
import sysconfig
 | 
			
		||||
import psutil
 | 
			
		||||
 | 
			
		||||
from fastapi import FastAPI, Request, status, HTTPException
 | 
			
		||||
from langchain.llms import RWKV
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from sse_starlette.sse import EventSourceResponse
 | 
			
		||||
from fastapi import FastAPI
 | 
			
		||||
from fastapi.middleware.cors import CORSMiddleware
 | 
			
		||||
import uvicorn
 | 
			
		||||
 | 
			
		||||
from rwkv_helper import rwkv_generate
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_torch():
 | 
			
		||||
    torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
 | 
			
		||||
    paths = os.environ.get("PATH", "")
 | 
			
		||||
    if os.path.exists(torch_path):
 | 
			
		||||
        print(f"torch found: {torch_path}")
 | 
			
		||||
        if torch_path in paths:
 | 
			
		||||
            print("torch already set")
 | 
			
		||||
        else:
 | 
			
		||||
            print("run:")
 | 
			
		||||
            os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
 | 
			
		||||
            print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
 | 
			
		||||
    else:
 | 
			
		||||
        print("torch not found")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def torch_gc():
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
    if torch.cuda.is_available():
 | 
			
		||||
        with torch.cuda.device(0):
 | 
			
		||||
            torch.cuda.empty_cache()
 | 
			
		||||
            torch.cuda.ipc_collect()
 | 
			
		||||
from utils.rwkv import *
 | 
			
		||||
from utils.torch import *
 | 
			
		||||
from utils.ngrok import *
 | 
			
		||||
from routes import completion, config
 | 
			
		||||
import global_var
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
@ -49,87 +22,33 @@ app.add_middleware(
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.include_router(completion.router)
 | 
			
		||||
app.include_router(config.router)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.on_event('startup')
 | 
			
		||||
def init():
 | 
			
		||||
    global model
 | 
			
		||||
    global_var.init()
 | 
			
		||||
 | 
			
		||||
    set_torch()
 | 
			
		||||
 | 
			
		||||
    model = RWKV(
 | 
			
		||||
        model=sys.argv[2],
 | 
			
		||||
        strategy=sys.argv[1],
 | 
			
		||||
        tokens_path=f"{pathlib.Path(__file__).parent.resolve()}/20B_tokenizer.json"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if os.environ.get("ngrok_token") is not None:
 | 
			
		||||
        ngrok_connect()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ngrok_connect():
 | 
			
		||||
    from pyngrok import ngrok, conf
 | 
			
		||||
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
 | 
			
		||||
    ngrok.set_auth_token(os.environ["ngrok_token"])
 | 
			
		||||
    http_tunnel = ngrok.connect(8000)
 | 
			
		||||
    print(http_tunnel.public_url)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Message(BaseModel):
 | 
			
		||||
    role: str
 | 
			
		||||
    content: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Body(BaseModel):
 | 
			
		||||
    messages: List[Message]
 | 
			
		||||
    model: str
 | 
			
		||||
    stream: bool
 | 
			
		||||
    max_tokens: int
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.get("/")
 | 
			
		||||
def read_root():
 | 
			
		||||
    return {"Hello": "World!"}
 | 
			
		||||
 | 
			
		||||
@app.post("update-config")
 | 
			
		||||
def updateConfig(body: Body):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.post("/v1/chat/completions")
 | 
			
		||||
@app.post("/chat/completions")
 | 
			
		||||
async def completions(body: Body, request: Request):
 | 
			
		||||
    global model
 | 
			
		||||
 | 
			
		||||
    question = body.messages[-1]
 | 
			
		||||
    if question.role == 'user':
 | 
			
		||||
        question = question.content
 | 
			
		||||
    else:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")
 | 
			
		||||
 | 
			
		||||
    completion_text = ""
 | 
			
		||||
    for message in body.messages:
 | 
			
		||||
        if message.role == 'user':
 | 
			
		||||
            completion_text += "Bob: " + message.content + "\n\n"
 | 
			
		||||
        elif message.role == 'assistant':
 | 
			
		||||
            completion_text += "Alice: " + message.content + "\n\n"
 | 
			
		||||
    completion_text += "Alice:"
 | 
			
		||||
 | 
			
		||||
    async def eval_rwkv():
 | 
			
		||||
        if body.stream:
 | 
			
		||||
            for response, delta in rwkv_generate(model, completion_text):
 | 
			
		||||
                if await request.is_disconnected():
 | 
			
		||||
                    break
 | 
			
		||||
                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
 | 
			
		||||
            yield "[DONE]"
 | 
			
		||||
        else:
 | 
			
		||||
            response = None
 | 
			
		||||
            for response, delta in rwkv_generate(model, completion_text):
 | 
			
		||||
                pass
 | 
			
		||||
            yield json.dumps({"response": response, "model": "rwkv"})
 | 
			
		||||
        # torch_gc()
 | 
			
		||||
 | 
			
		||||
    return EventSourceResponse(eval_rwkv())
 | 
			
		||||
@app.post("/exit")
 | 
			
		||||
def read_root():
 | 
			
		||||
    parent_pid = os.getpid()
 | 
			
		||||
    parent = psutil.Process(parent_pid)
 | 
			
		||||
    for child in parent.children(recursive=True):
 | 
			
		||||
        child.kill()
 | 
			
		||||
    parent.kill()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    uvicorn.run("main:app", reload=False, app_dir="backend-python")
 | 
			
		||||
    uvicorn.run("main:app", port=8000)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										58
									
								
								backend-python/routes/completion.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								backend-python/routes/completion.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,58 @@
 | 
			
		||||
import json
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
from fastapi import APIRouter, Request, status, HTTPException
 | 
			
		||||
from sse_starlette.sse import EventSourceResponse
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from utils.rwkv import *
 | 
			
		||||
import global_var
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Message(BaseModel):
 | 
			
		||||
    role: str
 | 
			
		||||
    content: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CompletionBody(BaseModel):
 | 
			
		||||
    messages: List[Message]
 | 
			
		||||
    model: str
 | 
			
		||||
    stream: bool
 | 
			
		||||
    max_tokens: int
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/v1/chat/completions")
 | 
			
		||||
@router.post("/chat/completions")
 | 
			
		||||
async def completions(body: CompletionBody, request: Request):
 | 
			
		||||
    model = global_var.get(global_var.Model)
 | 
			
		||||
 | 
			
		||||
    question = body.messages[-1]
 | 
			
		||||
    if question.role == 'user':
 | 
			
		||||
        question = question.content
 | 
			
		||||
    else:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
 | 
			
		||||
 | 
			
		||||
    completion_text = ""
 | 
			
		||||
    for message in body.messages:
 | 
			
		||||
        if message.role == 'user':
 | 
			
		||||
            completion_text += "Bob: " + message.content + "\n\n"
 | 
			
		||||
        elif message.role == 'assistant':
 | 
			
		||||
            completion_text += "Alice: " + message.content + "\n\n"
 | 
			
		||||
    completion_text += "Alice:"
 | 
			
		||||
 | 
			
		||||
    async def eval_rwkv():
 | 
			
		||||
        if body.stream:
 | 
			
		||||
            for response, delta in rwkv_generate(model, completion_text):
 | 
			
		||||
                if await request.is_disconnected():
 | 
			
		||||
                    break
 | 
			
		||||
                yield json.dumps({"response": response, "choices": [{"delta": {"content": delta}}], "model": "rwkv"})
 | 
			
		||||
            yield "[DONE]"
 | 
			
		||||
        else:
 | 
			
		||||
            response = None
 | 
			
		||||
            for response, delta in rwkv_generate(model, completion_text):
 | 
			
		||||
                pass
 | 
			
		||||
            yield json.dumps({"response": response, "model": "rwkv"})
 | 
			
		||||
        # torch_gc()
 | 
			
		||||
 | 
			
		||||
    return EventSourceResponse(eval_rwkv())
 | 
			
		||||
							
								
								
									
										46
									
								
								backend-python/routes/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								backend-python/routes/config.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,46 @@
 | 
			
		||||
import pathlib
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
from fastapi import APIRouter, HTTPException, status
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from langchain.llms import RWKV
 | 
			
		||||
from utils.rwkv import *
 | 
			
		||||
from utils.torch import *
 | 
			
		||||
import global_var
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UpdateConfigBody(BaseModel):
 | 
			
		||||
    model: str = None
 | 
			
		||||
    strategy: str = None
 | 
			
		||||
    max_response_token: int = None
 | 
			
		||||
    temperature: float = None
 | 
			
		||||
    top_p: float = None
 | 
			
		||||
    presence_penalty: float = None
 | 
			
		||||
    count_penalty: float = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/update-config")
 | 
			
		||||
def update_config(body: UpdateConfigBody):
 | 
			
		||||
    if (global_var.get(global_var.Model_Status) is global_var.ModelStatus.Loading):
 | 
			
		||||
        return "loading"
 | 
			
		||||
 | 
			
		||||
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
 | 
			
		||||
    global_var.set(global_var.Model, None)
 | 
			
		||||
    torch_gc()
 | 
			
		||||
 | 
			
		||||
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
 | 
			
		||||
    try:
 | 
			
		||||
        global_var.set(global_var.Model, RWKV(
 | 
			
		||||
            model=sys.argv[2],
 | 
			
		||||
            strategy=sys.argv[1],
 | 
			
		||||
            tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
 | 
			
		||||
        ))
 | 
			
		||||
    except Exception:
 | 
			
		||||
        global_var.set(global_var.Model_Status, global_var.ModelStatus.Offline)
 | 
			
		||||
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "failed to load")
 | 
			
		||||
 | 
			
		||||
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Working)
 | 
			
		||||
 | 
			
		||||
    return "success"
 | 
			
		||||
							
								
								
									
										9
									
								
								backend-python/utils/ngrok.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								backend-python/utils/ngrok.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,9 @@
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def ngrok_connect():
 | 
			
		||||
    from pyngrok import ngrok, conf
 | 
			
		||||
    conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
 | 
			
		||||
    ngrok.set_auth_token(os.environ["ngrok_token"])
 | 
			
		||||
    http_tunnel = ngrok.connect(8000)
 | 
			
		||||
    print(http_tunnel.public_url)
 | 
			
		||||
@ -15,8 +15,8 @@ def rwkv_generate(model: RWKV, prompt: str):
 | 
			
		||||
    for i in range(model.max_tokens_per_generation):
 | 
			
		||||
        for n in occurrence:
 | 
			
		||||
            logits[n] -= (
 | 
			
		||||
                    model.penalty_alpha_presence
 | 
			
		||||
                    + occurrence[n] * model.penalty_alpha_frequency
 | 
			
		||||
                model.penalty_alpha_presence
 | 
			
		||||
                + occurrence[n] * model.penalty_alpha_frequency
 | 
			
		||||
            )
 | 
			
		||||
        token = model.pipeline.sample_logits(
 | 
			
		||||
            logits, temperature=model.temperature, top_p=model.top_p
 | 
			
		||||
							
								
								
									
										26
									
								
								backend-python/utils/torch.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								backend-python/utils/torch.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,26 @@
 | 
			
		||||
import os
 | 
			
		||||
import sysconfig
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def set_torch():
 | 
			
		||||
    torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
 | 
			
		||||
    paths = os.environ.get("PATH", "")
 | 
			
		||||
    if os.path.exists(torch_path):
 | 
			
		||||
        print(f"torch found: {torch_path}")
 | 
			
		||||
        if torch_path in paths:
 | 
			
		||||
            print("torch already set")
 | 
			
		||||
        else:
 | 
			
		||||
            print("run:")
 | 
			
		||||
            os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
 | 
			
		||||
            print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
 | 
			
		||||
    else:
 | 
			
		||||
        print("torch not found")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def torch_gc():
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
    if torch.cuda.is_available():
 | 
			
		||||
        with torch.cuda.device(0):
 | 
			
		||||
            torch.cuda.empty_cache()
 | 
			
		||||
            torch.cuda.ipc_collect()
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user