add tps console output

This commit is contained in:
josc146 2024-05-10 16:19:21 +08:00
parent 14461930ab
commit 2ddcd17d23
2 changed files with 13 additions and 0 deletions

View File

@@ -4,6 +4,7 @@ from threading import Lock
from typing import List, Union
from enum import Enum
import base64
import time
from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse
@@ -151,10 +152,13 @@ async def eval_rwkv(
print(get_rwkv_config(model))
response, prompt_tokens, completion_tokens = "", 0, 0
completion_start_time = None
for response, delta, prompt_tokens, completion_tokens in model.generate(
prompt,
stop=stop,
):
if not completion_start_time:
completion_start_time = time.time()
if await request.is_disconnected():
break
if stream:
@@ -186,6 +190,10 @@ async def eval_rwkv(
)
# torch_gc()
requests_num = requests_num - 1
completion_end_time = time.time()
tps = completion_tokens / (completion_end_time - completion_start_time)
print(f"Generation TPS: {tps:.2f}")
if await request.is_disconnected():
print(f"{request.client} Stop Waiting")
quick_log(

View File

@@ -4,6 +4,7 @@ import os
import pathlib
import copy
import re
import time
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
from utils.log import quick_log
from fastapi import HTTPException
@@ -245,9 +246,13 @@ class AbstractRWKV(ABC):
prompt_token_len = 0
if delta_prompt != "":
prompt_start_time = time.time()
logits, prompt_token_len = self.run_rnn(
self.fix_tokens(self.pipeline.encode(delta_prompt))
)
prompt_end_time = time.time()
tps = prompt_token_len / (prompt_end_time - prompt_start_time)
print(f"Prompt Prefill TPS: {tps:.2f}", end=" ", flush=True)
try:
state_cache.add_state(
state_cache.AddStateBody(