diff --git a/backend-python/routes/completion.py b/backend-python/routes/completion.py index 36b436d..0fd1b39 100644 --- a/backend-python/routes/completion.py +++ b/backend-python/routes/completion.py @@ -25,7 +25,7 @@ class ChatCompletionBody(ModelConfigBody): messages: List[Message] model: str = "rwkv" stream: bool = False - stop: str = None + stop: str | List[str] = None class Config: schema_extra = { @@ -47,7 +47,7 @@ class CompletionBody(ModelConfigBody): prompt: Union[str, List[str]] model: str = "rwkv" stream: bool = False - stop: str = None + stop: str | List[str] = None class Config: schema_extra = { diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index e39126c..ca54e1c 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod import os import pathlib import copy -from typing import Dict, List, Tuple +import re +from typing import Dict, Iterable, List, Tuple from utils.log import quick_log from fastapi import HTTPException from pydantic import BaseModel, Field @@ -276,21 +277,40 @@ class AbstractRWKV(ABC): if "\ufffd" not in delta: # avoid utf-8 display issues response += delta if stop is not None: - if stop in response: - try: - state_cache.add_state( - state_cache.AddStateBody( - prompt=prompt + response, - tokens=self.model_tokens, - state=self.model_state, - logits=logits, + if type(stop) == str: + if stop in response: + try: + state_cache.add_state( + state_cache.AddStateBody( + prompt=prompt + response, + tokens=self.model_tokens, + state=self.model_state, + logits=logits, + ) ) - ) - except HTTPException: - pass - response = response.split(stop)[0] - yield response, "", prompt_token_len, completion_token_len - break + except HTTPException: + pass + response = response.split(stop)[0] + yield response, "", prompt_token_len, completion_token_len + break + elif type(stop) == list: + stop_exist_regex = "|".join(stop) + matched = re.search(stop_exist_regex, response) + if matched: + try: + state_cache.add_state( + state_cache.AddStateBody( + prompt=prompt + response, + tokens=self.model_tokens, + state=self.model_state, + logits=logits, + ) + ) + except HTTPException: + pass + response = response.split(matched.group())[0] + yield response, "", prompt_token_len, completion_token_len + break out_last = begin + i + 1 if i == self.max_tokens_per_generation - 1: try: