feat: use model state cache to achieve 5x - 50x faster preparation time for generation
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import os
|
||||
import pathlib
|
||||
import copy
|
||||
from typing import Dict, List
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from rwkv_pip.utils import PIPELINE
|
||||
from routes import state_cache
|
||||
|
||||
|
||||
END_OF_TEXT = 0
|
||||
@@ -61,9 +64,37 @@ class RWKV:
|
||||
return out
|
||||
|
||||
def generate(self, prompt: str, stop: str = None):
|
||||
self.model_state = None
|
||||
self.model_tokens = []
|
||||
logits = self.run_rnn(self.pipeline.encode(prompt))
|
||||
cache = None
|
||||
delta_prompt = prompt
|
||||
try:
|
||||
cache = state_cache.longest_prefix_state(
|
||||
state_cache.LongestPrefixStateBody(prompt=prompt)
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
if cache is None or cache["prompt"] == "":
|
||||
self.model_state = None
|
||||
self.model_tokens = []
|
||||
else:
|
||||
delta_prompt = prompt[len(cache["prompt"]) :]
|
||||
self.model_state = copy.deepcopy(cache["state"])
|
||||
self.model_tokens = copy.deepcopy(cache["tokens"])
|
||||
logits = copy.deepcopy(cache["logits"])
|
||||
|
||||
if delta_prompt != "":
|
||||
logits = self.run_rnn(self.pipeline.encode(delta_prompt))
|
||||
try:
|
||||
state_cache.add_state(
|
||||
state_cache.AddStateBody(
|
||||
prompt=prompt,
|
||||
tokens=self.model_tokens,
|
||||
state=self.model_state,
|
||||
logits=logits,
|
||||
)
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
begin = len(self.model_tokens)
|
||||
out_last = begin
|
||||
|
||||
@@ -94,9 +125,32 @@ class RWKV:
|
||||
if stop is not None:
|
||||
if stop in response:
|
||||
response = response.split(stop)[0]
|
||||
try:
|
||||
state_cache.add_state(
|
||||
state_cache.AddStateBody(
|
||||
prompt=prompt + response,
|
||||
tokens=self.model_tokens,
|
||||
state=self.model_state,
|
||||
logits=logits,
|
||||
)
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
yield response, ""
|
||||
break
|
||||
out_last = begin + i + 1
|
||||
if i == self.max_tokens_per_generation - 1:
|
||||
try:
|
||||
state_cache.add_state(
|
||||
state_cache.AddStateBody(
|
||||
prompt=prompt + response,
|
||||
tokens=self.model_tokens,
|
||||
state=self.model_state,
|
||||
logits=logits,
|
||||
)
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
yield response, delta
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user