Add TPS (tokens-per-second) console output

This commit is contained in:
josc146 2024-05-10 16:19:21 +08:00
parent 14461930ab
commit 2ddcd17d23
2 changed files with 13 additions and 0 deletions

View File

@ -4,6 +4,7 @@ from threading import Lock
from typing import List, Union from typing import List, Union
from enum import Enum from enum import Enum
import base64 import base64
import time
from fastapi import APIRouter, Request, status, HTTPException from fastapi import APIRouter, Request, status, HTTPException
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
@ -151,10 +152,13 @@ async def eval_rwkv(
print(get_rwkv_config(model)) print(get_rwkv_config(model))
response, prompt_tokens, completion_tokens = "", 0, 0 response, prompt_tokens, completion_tokens = "", 0, 0
completion_start_time = None
for response, delta, prompt_tokens, completion_tokens in model.generate( for response, delta, prompt_tokens, completion_tokens in model.generate(
prompt, prompt,
stop=stop, stop=stop,
): ):
if not completion_start_time:
completion_start_time = time.time()
if await request.is_disconnected(): if await request.is_disconnected():
break break
if stream: if stream:
@ -186,6 +190,10 @@ async def eval_rwkv(
) )
# torch_gc() # torch_gc()
requests_num = requests_num - 1 requests_num = requests_num - 1
completion_end_time = time.time()
tps = completion_tokens / (completion_end_time - completion_start_time)
print(f"Generation TPS: {tps:.2f}")
if await request.is_disconnected(): if await request.is_disconnected():
print(f"{request.client} Stop Waiting") print(f"{request.client} Stop Waiting")
quick_log( quick_log(

View File

@ -4,6 +4,7 @@ import os
import pathlib import pathlib
import copy import copy
import re import re
import time
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
from utils.log import quick_log from utils.log import quick_log
from fastapi import HTTPException from fastapi import HTTPException
@ -245,9 +246,13 @@ class AbstractRWKV(ABC):
prompt_token_len = 0 prompt_token_len = 0
if delta_prompt != "": if delta_prompt != "":
prompt_start_time = time.time()
logits, prompt_token_len = self.run_rnn( logits, prompt_token_len = self.run_rnn(
self.fix_tokens(self.pipeline.encode(delta_prompt)) self.fix_tokens(self.pipeline.encode(delta_prompt))
) )
prompt_end_time = time.time()
tps = prompt_token_len / (prompt_end_time - prompt_start_time)
print(f"Prompt Prefill TPS: {tps:.2f}", end=" ", flush=True)
try: try:
state_cache.add_state( state_cache.add_state(
state_cache.AddStateBody( state_cache.AddStateBody(