fix a tps error

This commit is contained in:
josc146
2024-05-16 13:48:06 +08:00
parent e1c12202aa
commit b24a18cd3a
2 changed files with 8 additions and 2 deletions

View File

@@ -257,7 +257,10 @@ class AbstractRWKV(ABC):
self.fix_tokens(self.pipeline.encode(delta_prompt))
)
prompt_end_time = time.time()
tps = prompt_token_len / (prompt_end_time - prompt_start_time)
prompt_interval = prompt_end_time - prompt_start_time
tps = 0
if prompt_interval > 0:
tps = prompt_token_len / prompt_interval
print(f"Prompt Prefill TPS: {tps:.2f}", end=" ", flush=True)
try:
state_cache.add_state(