Add prompt-prefill TPS (tokens per second) console output

This commit is contained in:
josc146
2024-05-10 16:19:21 +08:00
parent 14461930ab
commit 2ddcd17d23
2 changed files with 13 additions and 0 deletions

View File

@@ -4,6 +4,7 @@ import os
import pathlib
import copy
import re
import time
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
from utils.log import quick_log
from fastapi import HTTPException
@@ -245,9 +246,13 @@ class AbstractRWKV(ABC):
prompt_token_len = 0
if delta_prompt != "":
prompt_start_time = time.time()
logits, prompt_token_len = self.run_rnn(
self.fix_tokens(self.pipeline.encode(delta_prompt))
)
prompt_end_time = time.time()
tps = prompt_token_len / (prompt_end_time - prompt_start_time)
print(f"Prompt Prefill TPS: {tps:.2f}", end=" ", flush=True)
try:
state_cache.add_state(
state_cache.AddStateBody(