feat: use model state cache to achieve 5x - 50x faster preparation time for generation
This commit is contained in:
parent
822f2d729c
commit
3e11128c9d
@ -1,3 +1,4 @@
|
|||||||
|
import cyac
|
||||||
import GPUtil
|
import GPUtil
|
||||||
import torch
|
import torch
|
||||||
import rwkv
|
import rwkv
|
||||||
|
@ -11,7 +11,7 @@ import uvicorn
|
|||||||
from utils.rwkv import *
|
from utils.rwkv import *
|
||||||
from utils.torch import *
|
from utils.torch import *
|
||||||
from utils.ngrok import *
|
from utils.ngrok import *
|
||||||
from routes import completion, config
|
from routes import completion, config, state_cache
|
||||||
import global_var
|
import global_var
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@ -26,11 +26,13 @@ app.add_middleware(
|
|||||||
|
|
||||||
app.include_router(completion.router)
|
app.include_router(completion.router)
|
||||||
app.include_router(config.router)
|
app.include_router(config.router)
|
||||||
|
app.include_router(state_cache.router)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
def init():
|
def init():
|
||||||
global_var.init()
|
global_var.init()
|
||||||
|
state_cache.init()
|
||||||
|
|
||||||
set_torch()
|
set_torch()
|
||||||
|
|
||||||
|
Binary file not shown.
Binary file not shown.
@ -14,7 +14,7 @@ router = APIRouter()
|
|||||||
def get_tokens_path(model_path: str):
|
def get_tokens_path(model_path: str):
|
||||||
model_path = model_path.lower()
|
model_path = model_path.lower()
|
||||||
default_tokens_path = (
|
default_tokens_path = (
|
||||||
f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
|
f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
|
||||||
)
|
)
|
||||||
if "raven" in model_path:
|
if "raven" in model_path:
|
||||||
return default_tokens_path
|
return default_tokens_path
|
||||||
|
98
backend-python/routes/state_cache.py
Normal file
98
backend-python/routes/state_cache.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from typing import Any, Dict
|
||||||
|
from fastapi import APIRouter, HTTPException, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import gc
|
||||||
|
import copy
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
trie = None
|
||||||
|
dtrie: Dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
global trie
|
||||||
|
try:
|
||||||
|
import cyac
|
||||||
|
import mmap
|
||||||
|
import os
|
||||||
|
|
||||||
|
if os.path.exists("state_cache.trie"):
|
||||||
|
with open("state_cache.trie", "r") as bf:
|
||||||
|
buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
|
||||||
|
trie = cyac.Trie.from_buff(buff_object, copy=False)
|
||||||
|
else:
|
||||||
|
trie = cyac.Trie()
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
print("cyac not found")
|
||||||
|
|
||||||
|
|
||||||
|
class AddStateBody(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
tokens: list[str]
|
||||||
|
state: Any
|
||||||
|
logits: Any
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/add-state")
|
||||||
|
def add_state(body: AddStateBody):
|
||||||
|
global trie, dtrie
|
||||||
|
if trie is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
id = trie.insert(body.prompt)
|
||||||
|
dtrie[id] = {
|
||||||
|
"tokens": copy.deepcopy(body.tokens),
|
||||||
|
"state": copy.deepcopy(body.state),
|
||||||
|
"logits": copy.deepcopy(body.logits),
|
||||||
|
}
|
||||||
|
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reset-state")
|
||||||
|
def reset_state():
|
||||||
|
global trie
|
||||||
|
if trie is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
trie = cyac.Trie()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
|
class LongestPrefixStateBody(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/longest-prefix-state")
|
||||||
|
def longest_prefix_state(body: LongestPrefixStateBody):
|
||||||
|
global trie
|
||||||
|
if trie is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
id = -1
|
||||||
|
for id, len in trie.prefix(body.prompt):
|
||||||
|
pass
|
||||||
|
if id != -1:
|
||||||
|
v = dtrie[id]
|
||||||
|
return {
|
||||||
|
"prompt": trie[id],
|
||||||
|
"tokens": v["tokens"],
|
||||||
|
"state": v["state"],
|
||||||
|
"logits": v["logits"],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {"prompt": "", "tokens": [], "state": None, "logits": None}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/save-state")
|
||||||
|
def save_state():
|
||||||
|
global trie
|
||||||
|
if trie is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||||
|
|
||||||
|
trie.save("state_cache.trie")
|
||||||
|
|
||||||
|
return "success"
|
@ -1,8 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import copy
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from rwkv_pip.utils import PIPELINE
|
from rwkv_pip.utils import PIPELINE
|
||||||
|
from routes import state_cache
|
||||||
|
|
||||||
|
|
||||||
END_OF_TEXT = 0
|
END_OF_TEXT = 0
|
||||||
@ -61,9 +64,37 @@ class RWKV:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def generate(self, prompt: str, stop: str = None):
|
def generate(self, prompt: str, stop: str = None):
|
||||||
|
cache = None
|
||||||
|
delta_prompt = prompt
|
||||||
|
try:
|
||||||
|
cache = state_cache.longest_prefix_state(
|
||||||
|
state_cache.LongestPrefixStateBody(prompt=prompt)
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
|
if cache is None or cache["prompt"] == "":
|
||||||
self.model_state = None
|
self.model_state = None
|
||||||
self.model_tokens = []
|
self.model_tokens = []
|
||||||
logits = self.run_rnn(self.pipeline.encode(prompt))
|
else:
|
||||||
|
delta_prompt = prompt[len(cache["prompt"]) :]
|
||||||
|
self.model_state = copy.deepcopy(cache["state"])
|
||||||
|
self.model_tokens = copy.deepcopy(cache["tokens"])
|
||||||
|
logits = copy.deepcopy(cache["logits"])
|
||||||
|
|
||||||
|
if delta_prompt != "":
|
||||||
|
logits = self.run_rnn(self.pipeline.encode(delta_prompt))
|
||||||
|
try:
|
||||||
|
state_cache.add_state(
|
||||||
|
state_cache.AddStateBody(
|
||||||
|
prompt=prompt,
|
||||||
|
tokens=self.model_tokens,
|
||||||
|
state=self.model_state,
|
||||||
|
logits=logits,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
|
|
||||||
begin = len(self.model_tokens)
|
begin = len(self.model_tokens)
|
||||||
out_last = begin
|
out_last = begin
|
||||||
|
|
||||||
@ -94,9 +125,32 @@ class RWKV:
|
|||||||
if stop is not None:
|
if stop is not None:
|
||||||
if stop in response:
|
if stop in response:
|
||||||
response = response.split(stop)[0]
|
response = response.split(stop)[0]
|
||||||
|
try:
|
||||||
|
state_cache.add_state(
|
||||||
|
state_cache.AddStateBody(
|
||||||
|
prompt=prompt + response,
|
||||||
|
tokens=self.model_tokens,
|
||||||
|
state=self.model_state,
|
||||||
|
logits=logits,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
yield response, ""
|
yield response, ""
|
||||||
break
|
break
|
||||||
out_last = begin + i + 1
|
out_last = begin + i + 1
|
||||||
|
if i == self.max_tokens_per_generation - 1:
|
||||||
|
try:
|
||||||
|
state_cache.add_state(
|
||||||
|
state_cache.AddStateBody(
|
||||||
|
prompt=prompt + response,
|
||||||
|
tokens=self.model_tokens,
|
||||||
|
state=self.model_state,
|
||||||
|
logits=logits,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
yield response, delta
|
yield response, delta
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user