feat: use model state cache to achieve 5x - 50x faster preparation time for generation

This commit is contained in:
josc146
2023-05-28 23:52:38 +08:00
parent 822f2d729c
commit 3e11128c9d
7 changed files with 160 additions and 5 deletions

View File

@@ -1,8 +1,11 @@
import os
import pathlib
import copy
from typing import Dict, List
from fastapi import HTTPException
from pydantic import BaseModel
from rwkv_pip.utils import PIPELINE
from routes import state_cache
END_OF_TEXT = 0
@@ -61,9 +64,37 @@ class RWKV:
return out
def generate(self, prompt: str, stop: str = None):
self.model_state = None
self.model_tokens = []
logits = self.run_rnn(self.pipeline.encode(prompt))
cache = None
delta_prompt = prompt
try:
cache = state_cache.longest_prefix_state(
state_cache.LongestPrefixStateBody(prompt=prompt)
)
except HTTPException:
pass
if cache is None or cache["prompt"] == "":
self.model_state = None
self.model_tokens = []
else:
delta_prompt = prompt[len(cache["prompt"]) :]
self.model_state = copy.deepcopy(cache["state"])
self.model_tokens = copy.deepcopy(cache["tokens"])
logits = copy.deepcopy(cache["logits"])
if delta_prompt != "":
logits = self.run_rnn(self.pipeline.encode(delta_prompt))
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=prompt,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
)
)
except HTTPException:
pass
begin = len(self.model_tokens)
out_last = begin
@@ -94,9 +125,32 @@ class RWKV:
if stop is not None:
if stop in response:
response = response.split(stop)[0]
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=prompt + response,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
)
)
except HTTPException:
pass
yield response, ""
break
out_last = begin + i + 1
if i == self.max_tokens_per_generation - 1:
try:
state_cache.add_state(
state_cache.AddStateBody(
prompt=prompt + response,
tokens=self.model_tokens,
state=self.model_state,
logits=logits,
)
)
except HTTPException:
pass
yield response, delta