feat: use model state cache to achieve 5x - 50x faster preparation time for generation
This commit is contained in:
		
							parent
							
								
									822f2d729c
								
							
						
					
					
						commit
						3e11128c9d
					
				@ -1,3 +1,4 @@
 | 
			
		||||
import cyac
 | 
			
		||||
import GPUtil
 | 
			
		||||
import torch
 | 
			
		||||
import rwkv
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ import uvicorn
 | 
			
		||||
from utils.rwkv import *
 | 
			
		||||
from utils.torch import *
 | 
			
		||||
from utils.ngrok import *
 | 
			
		||||
from routes import completion, config
 | 
			
		||||
from routes import completion, config, state_cache
 | 
			
		||||
import global_var
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
@ -26,11 +26,13 @@ app.add_middleware(
 | 
			
		||||
 | 
			
		||||
app.include_router(completion.router)
 | 
			
		||||
app.include_router(config.router)
 | 
			
		||||
app.include_router(state_cache.router)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.on_event("startup")
 | 
			
		||||
def init():
 | 
			
		||||
    global_var.init()
 | 
			
		||||
    state_cache.init()
 | 
			
		||||
 | 
			
		||||
    set_torch()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							@ -14,7 +14,7 @@ router = APIRouter()
 | 
			
		||||
def get_tokens_path(model_path: str):
 | 
			
		||||
    model_path = model_path.lower()
 | 
			
		||||
    default_tokens_path = (
 | 
			
		||||
        f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
 | 
			
		||||
        f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
 | 
			
		||||
    )
 | 
			
		||||
    if "raven" in model_path:
 | 
			
		||||
        return default_tokens_path
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										98
									
								
								backend-python/routes/state_cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								backend-python/routes/state_cache.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,98 @@
 | 
			
		||||
from typing import Any, Dict
 | 
			
		||||
from fastapi import APIRouter, HTTPException, Response, status
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
import gc
 | 
			
		||||
import copy
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
trie = None
 | 
			
		||||
dtrie: Dict = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init():
 | 
			
		||||
    global trie
 | 
			
		||||
    try:
 | 
			
		||||
        import cyac
 | 
			
		||||
        import mmap
 | 
			
		||||
        import os
 | 
			
		||||
 | 
			
		||||
        if os.path.exists("state_cache.trie"):
 | 
			
		||||
            with open("state_cache.trie", "r") as bf:
 | 
			
		||||
                buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
 | 
			
		||||
            trie = cyac.Trie.from_buff(buff_object, copy=False)
 | 
			
		||||
        else:
 | 
			
		||||
            trie = cyac.Trie()
 | 
			
		||||
    except ModuleNotFoundError:
 | 
			
		||||
        print("cyac not found")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AddStateBody(BaseModel):
 | 
			
		||||
    prompt: str
 | 
			
		||||
    tokens: list[str]
 | 
			
		||||
    state: Any
 | 
			
		||||
    logits: Any
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/add-state")
 | 
			
		||||
def add_state(body: AddStateBody):
 | 
			
		||||
    global trie, dtrie
 | 
			
		||||
    if trie is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    id = trie.insert(body.prompt)
 | 
			
		||||
    dtrie[id] = {
 | 
			
		||||
        "tokens": copy.deepcopy(body.tokens),
 | 
			
		||||
        "state": copy.deepcopy(body.state),
 | 
			
		||||
        "logits": copy.deepcopy(body.logits),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return "success"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/reset-state")
 | 
			
		||||
def reset_state():
 | 
			
		||||
    global trie
 | 
			
		||||
    if trie is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    trie = cyac.Trie()
 | 
			
		||||
    gc.collect()
 | 
			
		||||
 | 
			
		||||
    return "success"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LongestPrefixStateBody(BaseModel):
 | 
			
		||||
    prompt: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/longest-prefix-state")
 | 
			
		||||
def longest_prefix_state(body: LongestPrefixStateBody):
 | 
			
		||||
    global trie
 | 
			
		||||
    if trie is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    id = -1
 | 
			
		||||
    for id, len in trie.prefix(body.prompt):
 | 
			
		||||
        pass
 | 
			
		||||
    if id != -1:
 | 
			
		||||
        v = dtrie[id]
 | 
			
		||||
        return {
 | 
			
		||||
            "prompt": trie[id],
 | 
			
		||||
            "tokens": v["tokens"],
 | 
			
		||||
            "state": v["state"],
 | 
			
		||||
            "logits": v["logits"],
 | 
			
		||||
        }
 | 
			
		||||
    else:
 | 
			
		||||
        return {"prompt": "", "tokens": [], "state": None, "logits": None}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("/save-state")
 | 
			
		||||
def save_state():
 | 
			
		||||
    global trie
 | 
			
		||||
    if trie is None:
 | 
			
		||||
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
			
		||||
 | 
			
		||||
    trie.save("state_cache.trie")
 | 
			
		||||
 | 
			
		||||
    return "success"
 | 
			
		||||
@ -1,8 +1,11 @@
 | 
			
		||||
import os
 | 
			
		||||
import pathlib
 | 
			
		||||
import copy
 | 
			
		||||
from typing import Dict, List
 | 
			
		||||
from fastapi import HTTPException
 | 
			
		||||
from pydantic import BaseModel
 | 
			
		||||
from rwkv_pip.utils import PIPELINE
 | 
			
		||||
from routes import state_cache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
END_OF_TEXT = 0
 | 
			
		||||
@ -61,9 +64,37 @@ class RWKV:
 | 
			
		||||
        return out
 | 
			
		||||
 | 
			
		||||
    def generate(self, prompt: str, stop: str = None):
 | 
			
		||||
        self.model_state = None
 | 
			
		||||
        self.model_tokens = []
 | 
			
		||||
        logits = self.run_rnn(self.pipeline.encode(prompt))
 | 
			
		||||
        cache = None
 | 
			
		||||
        delta_prompt = prompt
 | 
			
		||||
        try:
 | 
			
		||||
            cache = state_cache.longest_prefix_state(
 | 
			
		||||
                state_cache.LongestPrefixStateBody(prompt=prompt)
 | 
			
		||||
            )
 | 
			
		||||
        except HTTPException:
 | 
			
		||||
            pass
 | 
			
		||||
        if cache is None or cache["prompt"] == "":
 | 
			
		||||
            self.model_state = None
 | 
			
		||||
            self.model_tokens = []
 | 
			
		||||
        else:
 | 
			
		||||
            delta_prompt = prompt[len(cache["prompt"]) :]
 | 
			
		||||
            self.model_state = copy.deepcopy(cache["state"])
 | 
			
		||||
            self.model_tokens = copy.deepcopy(cache["tokens"])
 | 
			
		||||
            logits = copy.deepcopy(cache["logits"])
 | 
			
		||||
 | 
			
		||||
        if delta_prompt != "":
 | 
			
		||||
            logits = self.run_rnn(self.pipeline.encode(delta_prompt))
 | 
			
		||||
            try:
 | 
			
		||||
                state_cache.add_state(
 | 
			
		||||
                    state_cache.AddStateBody(
 | 
			
		||||
                        prompt=prompt,
 | 
			
		||||
                        tokens=self.model_tokens,
 | 
			
		||||
                        state=self.model_state,
 | 
			
		||||
                        logits=logits,
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            except HTTPException:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        begin = len(self.model_tokens)
 | 
			
		||||
        out_last = begin
 | 
			
		||||
 | 
			
		||||
@ -94,9 +125,32 @@ class RWKV:
 | 
			
		||||
                if stop is not None:
 | 
			
		||||
                    if stop in response:
 | 
			
		||||
                        response = response.split(stop)[0]
 | 
			
		||||
                        try:
 | 
			
		||||
                            state_cache.add_state(
 | 
			
		||||
                                state_cache.AddStateBody(
 | 
			
		||||
                                    prompt=prompt + response,
 | 
			
		||||
                                    tokens=self.model_tokens,
 | 
			
		||||
                                    state=self.model_state,
 | 
			
		||||
                                    logits=logits,
 | 
			
		||||
                                )
 | 
			
		||||
                            )
 | 
			
		||||
                        except HTTPException:
 | 
			
		||||
                            pass
 | 
			
		||||
                        yield response, ""
 | 
			
		||||
                        break
 | 
			
		||||
                out_last = begin + i + 1
 | 
			
		||||
                if i == self.max_tokens_per_generation - 1:
 | 
			
		||||
                    try:
 | 
			
		||||
                        state_cache.add_state(
 | 
			
		||||
                            state_cache.AddStateBody(
 | 
			
		||||
                                prompt=prompt + response,
 | 
			
		||||
                                tokens=self.model_tokens,
 | 
			
		||||
                                state=self.model_state,
 | 
			
		||||
                                logits=logits,
 | 
			
		||||
                            )
 | 
			
		||||
                        )
 | 
			
		||||
                    except HTTPException:
 | 
			
		||||
                        pass
 | 
			
		||||
                yield response, delta
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user