feat: use model state cache to achieve 5x - 50x faster preparation time for generation
This commit is contained in:
		
							parent
							
								
									822f2d729c
								
							
						
					
					
						commit
						3e11128c9d
					
				@ -1,3 +1,4 @@
 | 
				
			|||||||
 | 
					import cyac
 | 
				
			||||||
import GPUtil
 | 
					import GPUtil
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import rwkv
 | 
					import rwkv
 | 
				
			||||||
 | 
				
			|||||||
@ -11,7 +11,7 @@ import uvicorn
 | 
				
			|||||||
from utils.rwkv import *
 | 
					from utils.rwkv import *
 | 
				
			||||||
from utils.torch import *
 | 
					from utils.torch import *
 | 
				
			||||||
from utils.ngrok import *
 | 
					from utils.ngrok import *
 | 
				
			||||||
from routes import completion, config
 | 
					from routes import completion, config, state_cache
 | 
				
			||||||
import global_var
 | 
					import global_var
 | 
				
			||||||
 | 
					
 | 
				
			||||||
app = FastAPI()
 | 
					app = FastAPI()
 | 
				
			||||||
@ -26,11 +26,13 @@ app.add_middleware(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
app.include_router(completion.router)
 | 
					app.include_router(completion.router)
 | 
				
			||||||
app.include_router(config.router)
 | 
					app.include_router(config.router)
 | 
				
			||||||
 | 
					app.include_router(state_cache.router)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.on_event("startup")
 | 
					@app.on_event("startup")
 | 
				
			||||||
def init():
 | 
					def init():
 | 
				
			||||||
    global_var.init()
 | 
					    global_var.init()
 | 
				
			||||||
 | 
					    state_cache.init()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    set_torch()
 | 
					    set_torch()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							@ -14,7 +14,7 @@ router = APIRouter()
 | 
				
			|||||||
def get_tokens_path(model_path: str):
 | 
					def get_tokens_path(model_path: str):
 | 
				
			||||||
    model_path = model_path.lower()
 | 
					    model_path = model_path.lower()
 | 
				
			||||||
    default_tokens_path = (
 | 
					    default_tokens_path = (
 | 
				
			||||||
        f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
 | 
					        f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/20B_tokenizer.json"
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    if "raven" in model_path:
 | 
					    if "raven" in model_path:
 | 
				
			||||||
        return default_tokens_path
 | 
					        return default_tokens_path
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										98
									
								
								backend-python/routes/state_cache.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								backend-python/routes/state_cache.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,98 @@
 | 
				
			|||||||
 | 
					from typing import Any, Dict
 | 
				
			||||||
 | 
					from fastapi import APIRouter, HTTPException, Response, status
 | 
				
			||||||
 | 
					from pydantic import BaseModel
 | 
				
			||||||
 | 
					import gc
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					router = APIRouter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					trie = None
 | 
				
			||||||
 | 
					dtrie: Dict = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def init():
 | 
				
			||||||
 | 
					    global trie
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        import cyac
 | 
				
			||||||
 | 
					        import mmap
 | 
				
			||||||
 | 
					        import os
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if os.path.exists("state_cache.trie"):
 | 
				
			||||||
 | 
					            with open("state_cache.trie", "r") as bf:
 | 
				
			||||||
 | 
					                buff_object = mmap.mmap(bf.fileno(), 0, access=mmap.ACCESS_READ)
 | 
				
			||||||
 | 
					            trie = cyac.Trie.from_buff(buff_object, copy=False)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            trie = cyac.Trie()
 | 
				
			||||||
 | 
					    except ModuleNotFoundError:
 | 
				
			||||||
 | 
					        print("cyac not found")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AddStateBody(BaseModel):
 | 
				
			||||||
 | 
					    prompt: str
 | 
				
			||||||
 | 
					    tokens: list[str]
 | 
				
			||||||
 | 
					    state: Any
 | 
				
			||||||
 | 
					    logits: Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@router.post("/add-state")
 | 
				
			||||||
 | 
					def add_state(body: AddStateBody):
 | 
				
			||||||
 | 
					    global trie, dtrie
 | 
				
			||||||
 | 
					    if trie is None:
 | 
				
			||||||
 | 
					        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    id = trie.insert(body.prompt)
 | 
				
			||||||
 | 
					    dtrie[id] = {
 | 
				
			||||||
 | 
					        "tokens": copy.deepcopy(body.tokens),
 | 
				
			||||||
 | 
					        "state": copy.deepcopy(body.state),
 | 
				
			||||||
 | 
					        "logits": copy.deepcopy(body.logits),
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return "success"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@router.post("/reset-state")
 | 
				
			||||||
 | 
					def reset_state():
 | 
				
			||||||
 | 
					    global trie
 | 
				
			||||||
 | 
					    if trie is None:
 | 
				
			||||||
 | 
					        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    trie = cyac.Trie()
 | 
				
			||||||
 | 
					    gc.collect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return "success"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LongestPrefixStateBody(BaseModel):
 | 
				
			||||||
 | 
					    prompt: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@router.post("/longest-prefix-state")
 | 
				
			||||||
 | 
					def longest_prefix_state(body: LongestPrefixStateBody):
 | 
				
			||||||
 | 
					    global trie
 | 
				
			||||||
 | 
					    if trie is None:
 | 
				
			||||||
 | 
					        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    id = -1
 | 
				
			||||||
 | 
					    for id, len in trie.prefix(body.prompt):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					    if id != -1:
 | 
				
			||||||
 | 
					        v = dtrie[id]
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            "prompt": trie[id],
 | 
				
			||||||
 | 
					            "tokens": v["tokens"],
 | 
				
			||||||
 | 
					            "state": v["state"],
 | 
				
			||||||
 | 
					            "logits": v["logits"],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return {"prompt": "", "tokens": [], "state": None, "logits": None}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@router.post("/save-state")
 | 
				
			||||||
 | 
					def save_state():
 | 
				
			||||||
 | 
					    global trie
 | 
				
			||||||
 | 
					    if trie is None:
 | 
				
			||||||
 | 
					        raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    trie.save("state_cache.trie")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return "success"
 | 
				
			||||||
@ -1,8 +1,11 @@
 | 
				
			|||||||
import os
 | 
					import os
 | 
				
			||||||
import pathlib
 | 
					import pathlib
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
from typing import Dict, List
 | 
					from typing import Dict, List
 | 
				
			||||||
 | 
					from fastapi import HTTPException
 | 
				
			||||||
from pydantic import BaseModel
 | 
					from pydantic import BaseModel
 | 
				
			||||||
from rwkv_pip.utils import PIPELINE
 | 
					from rwkv_pip.utils import PIPELINE
 | 
				
			||||||
 | 
					from routes import state_cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
END_OF_TEXT = 0
 | 
					END_OF_TEXT = 0
 | 
				
			||||||
@ -61,9 +64,37 @@ class RWKV:
 | 
				
			|||||||
        return out
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def generate(self, prompt: str, stop: str = None):
 | 
					    def generate(self, prompt: str, stop: str = None):
 | 
				
			||||||
 | 
					        cache = None
 | 
				
			||||||
 | 
					        delta_prompt = prompt
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            cache = state_cache.longest_prefix_state(
 | 
				
			||||||
 | 
					                state_cache.LongestPrefixStateBody(prompt=prompt)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        except HTTPException:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        if cache is None or cache["prompt"] == "":
 | 
				
			||||||
            self.model_state = None
 | 
					            self.model_state = None
 | 
				
			||||||
            self.model_tokens = []
 | 
					            self.model_tokens = []
 | 
				
			||||||
        logits = self.run_rnn(self.pipeline.encode(prompt))
 | 
					        else:
 | 
				
			||||||
 | 
					            delta_prompt = prompt[len(cache["prompt"]) :]
 | 
				
			||||||
 | 
					            self.model_state = copy.deepcopy(cache["state"])
 | 
				
			||||||
 | 
					            self.model_tokens = copy.deepcopy(cache["tokens"])
 | 
				
			||||||
 | 
					            logits = copy.deepcopy(cache["logits"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if delta_prompt != "":
 | 
				
			||||||
 | 
					            logits = self.run_rnn(self.pipeline.encode(delta_prompt))
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                state_cache.add_state(
 | 
				
			||||||
 | 
					                    state_cache.AddStateBody(
 | 
				
			||||||
 | 
					                        prompt=prompt,
 | 
				
			||||||
 | 
					                        tokens=self.model_tokens,
 | 
				
			||||||
 | 
					                        state=self.model_state,
 | 
				
			||||||
 | 
					                        logits=logits,
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            except HTTPException:
 | 
				
			||||||
 | 
					                pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        begin = len(self.model_tokens)
 | 
					        begin = len(self.model_tokens)
 | 
				
			||||||
        out_last = begin
 | 
					        out_last = begin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -94,9 +125,32 @@ class RWKV:
 | 
				
			|||||||
                if stop is not None:
 | 
					                if stop is not None:
 | 
				
			||||||
                    if stop in response:
 | 
					                    if stop in response:
 | 
				
			||||||
                        response = response.split(stop)[0]
 | 
					                        response = response.split(stop)[0]
 | 
				
			||||||
 | 
					                        try:
 | 
				
			||||||
 | 
					                            state_cache.add_state(
 | 
				
			||||||
 | 
					                                state_cache.AddStateBody(
 | 
				
			||||||
 | 
					                                    prompt=prompt + response,
 | 
				
			||||||
 | 
					                                    tokens=self.model_tokens,
 | 
				
			||||||
 | 
					                                    state=self.model_state,
 | 
				
			||||||
 | 
					                                    logits=logits,
 | 
				
			||||||
 | 
					                                )
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                        except HTTPException:
 | 
				
			||||||
 | 
					                            pass
 | 
				
			||||||
                        yield response, ""
 | 
					                        yield response, ""
 | 
				
			||||||
                        break
 | 
					                        break
 | 
				
			||||||
                out_last = begin + i + 1
 | 
					                out_last = begin + i + 1
 | 
				
			||||||
 | 
					                if i == self.max_tokens_per_generation - 1:
 | 
				
			||||||
 | 
					                    try:
 | 
				
			||||||
 | 
					                        state_cache.add_state(
 | 
				
			||||||
 | 
					                            state_cache.AddStateBody(
 | 
				
			||||||
 | 
					                                prompt=prompt + response,
 | 
				
			||||||
 | 
					                                tokens=self.model_tokens,
 | 
				
			||||||
 | 
					                                state=self.model_state,
 | 
				
			||||||
 | 
					                                logits=logits,
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    except HTTPException:
 | 
				
			||||||
 | 
					                        pass
 | 
				
			||||||
                yield response, delta
 | 
					                yield response, delta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user