Add TPS (tokens per second) console output
This commit is contained in:
parent
14461930ab
commit
2ddcd17d23
@ -4,6 +4,7 @@ from threading import Lock
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import base64
|
import base64
|
||||||
|
import time
|
||||||
|
|
||||||
from fastapi import APIRouter, Request, status, HTTPException
|
from fastapi import APIRouter, Request, status, HTTPException
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
@ -151,10 +152,13 @@ async def eval_rwkv(
|
|||||||
print(get_rwkv_config(model))
|
print(get_rwkv_config(model))
|
||||||
|
|
||||||
response, prompt_tokens, completion_tokens = "", 0, 0
|
response, prompt_tokens, completion_tokens = "", 0, 0
|
||||||
|
completion_start_time = None
|
||||||
for response, delta, prompt_tokens, completion_tokens in model.generate(
|
for response, delta, prompt_tokens, completion_tokens in model.generate(
|
||||||
prompt,
|
prompt,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
):
|
):
|
||||||
|
if not completion_start_time:
|
||||||
|
completion_start_time = time.time()
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
break
|
break
|
||||||
if stream:
|
if stream:
|
||||||
@ -186,6 +190,10 @@ async def eval_rwkv(
|
|||||||
)
|
)
|
||||||
# torch_gc()
|
# torch_gc()
|
||||||
requests_num = requests_num - 1
|
requests_num = requests_num - 1
|
||||||
|
completion_end_time = time.time()
|
||||||
|
tps = completion_tokens / (completion_end_time - completion_start_time)
|
||||||
|
print(f"Generation TPS: {tps:.2f}")
|
||||||
|
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
print(f"{request.client} Stop Waiting")
|
print(f"{request.client} Stop Waiting")
|
||||||
quick_log(
|
quick_log(
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
import pathlib
|
import pathlib
|
||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
|
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
|
||||||
from utils.log import quick_log
|
from utils.log import quick_log
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
@ -245,9 +246,13 @@ class AbstractRWKV(ABC):
|
|||||||
|
|
||||||
prompt_token_len = 0
|
prompt_token_len = 0
|
||||||
if delta_prompt != "":
|
if delta_prompt != "":
|
||||||
|
prompt_start_time = time.time()
|
||||||
logits, prompt_token_len = self.run_rnn(
|
logits, prompt_token_len = self.run_rnn(
|
||||||
self.fix_tokens(self.pipeline.encode(delta_prompt))
|
self.fix_tokens(self.pipeline.encode(delta_prompt))
|
||||||
)
|
)
|
||||||
|
prompt_end_time = time.time()
|
||||||
|
tps = prompt_token_len / (prompt_end_time - prompt_start_time)
|
||||||
|
print(f"Prompt Prefill TPS: {tps:.2f}", end=" ", flush=True)
|
||||||
try:
|
try:
|
||||||
state_cache.add_state(
|
state_cache.add_state(
|
||||||
state_cache.AddStateBody(
|
state_cache.AddStateBody(
|
||||||
|
Loading…
Reference in New Issue
Block a user