From 2ddcd17d23a7548c63a7e698131e6b261a43b11c Mon Sep 17 00:00:00 2001 From: josc146 Date: Fri, 10 May 2024 16:19:21 +0800 Subject: [PATCH] add tps console output --- backend-python/routes/completion.py | 8 ++++++++ backend-python/utils/rwkv.py | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index b72ce68..93254c0 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -4,6 +4,7 @@ from threading import Lock from typing import List, Union from enum import Enum import base64 +import time from fastapi import APIRouter, Request, status, HTTPException from sse_starlette.sse import EventSourceResponse @@ -151,10 +152,13 @@ async def eval_rwkv( print(get_rwkv_config(model)) response, prompt_tokens, completion_tokens = "", 0, 0 + completion_start_time = None for response, delta, prompt_tokens, completion_tokens in model.generate( prompt, stop=stop, ): + if not completion_start_time: + completion_start_time = time.time() if await request.is_disconnected(): break if stream: @@ -186,6 +190,10 @@ async def eval_rwkv( ) # torch_gc() requests_num = requests_num - 1 + completion_end_time = time.time() + tps = completion_tokens / (completion_end_time - completion_start_time) + print(f"Generation TPS: {tps:.2f}") + if await request.is_disconnected(): print(f"{request.client} Stop Waiting") quick_log( diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 9297652..d8e4e9e 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -4,6 +4,7 @@ import os import pathlib import copy import re +import time from typing import Dict, Iterable, List, Tuple, Union, Type, Callable from utils.log import quick_log from fastapi import HTTPException @@ -245,9 +246,13 @@ class AbstractRWKV(ABC): prompt_token_len = 0 if delta_prompt != "": + prompt_start_time = time.time() logits, prompt_token_len = self.run_rnn( self.fix_tokens(self.pipeline.encode(delta_prompt)) ) + prompt_end_time = time.time() + tps = prompt_token_len / (prompt_end_time - prompt_start_time) + print(f"Prompt Prefill TPS: {tps:.2f}", end=" ", flush=True) try: state_cache.add_state( state_cache.AddStateBody(