improve state cache performance

This commit is contained in:
josc146
2023-12-28 22:15:31 +08:00
parent e083f2c629
commit 7f3cfd54b0
2 changed files with 14 additions and 28 deletions

View File

@@ -239,14 +239,9 @@ class AbstractRWKV(ABC):
self.model_tokens = []
else:
delta_prompt = prompt[len(cache["prompt"]) :]
state = cache["state"]
self.model_state = (
copy.deepcopy(state)
if type(state) == list or type(state) == np.ndarray
else state
)
self.model_tokens = copy.deepcopy(cache["tokens"])
logits = copy.deepcopy(cache["logits"])
self.model_state = cache["state"]
self.model_tokens = cache["tokens"]
logits = cache["logits"]
prompt_token_len = 0
if delta_prompt != "":