support for stop array

josc146 2023-07-25 16:10:22 +08:00
parent 05b9b42b56
commit 34095a6c36
2 changed files with 37 additions and 17 deletions

View File

@@ -25,7 +25,7 @@ class ChatCompletionBody(ModelConfigBody):
     messages: List[Message]
     model: str = "rwkv"
     stream: bool = False
-    stop: str = None
+    stop: str | List[str] = None
 
     class Config:
         schema_extra = {
@@ -47,7 +47,7 @@ class CompletionBody(ModelConfigBody):
     prompt: Union[str, List[str]]
     model: str = "rwkv"
     stream: bool = False
-    stop: str = None
+    stop: str | List[str] = None
 
     class Config:
         schema_extra = {
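
With this change, both ChatCompletionBody and CompletionBody accept stop either as a single string or as a list of stop strings. (Note that the bare str | List[str] union syntax is only valid at class-definition time on Python 3.10+; Union[str, List[str]] is the pre-3.10 equivalent.) Below is a minimal client-side sketch exercising the list form; the endpoint path, port, and message fields are assumptions based on the OpenAI-style schema these bodies mirror, not part of this diff.

import requests  # hypothetical client; any HTTP client works

payload = {
    "messages": [{"role": "user", "content": "Write one sentence about RWKV."}],
    "model": "rwkv",
    "stream": False,
    # "stop": "\n\nUser:",                   # still accepted: a single string
    "stop": ["\n\nUser:", "\n\nQuestion:"],  # new: a list of stop strings
}
resp = requests.post("http://127.0.0.1:8000/chat/completions", json=payload)
print(resp.json())
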

View File

@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
 import os
 import pathlib
 import copy
-from typing import Dict, List, Tuple
+import re
+from typing import Dict, Iterable, List, Tuple
 from utils.log import quick_log
 from fastapi import HTTPException
 from pydantic import BaseModel, Field
@@ -276,21 +277,40 @@ class AbstractRWKV(ABC):
             if "\ufffd" not in delta:  # avoid utf-8 display issues
                 response += delta
                 if stop is not None:
-                    if stop in response:
-                        try:
-                            state_cache.add_state(
-                                state_cache.AddStateBody(
-                                    prompt=prompt + response,
-                                    tokens=self.model_tokens,
-                                    state=self.model_state,
-                                    logits=logits,
+                    if type(stop) == str:
+                        if stop in response:
+                            try:
+                                state_cache.add_state(
+                                    state_cache.AddStateBody(
+                                        prompt=prompt + response,
+                                        tokens=self.model_tokens,
+                                        state=self.model_state,
+                                        logits=logits,
+                                    )
                                 )
-                            )
-                        except HTTPException:
-                            pass
-                        response = response.split(stop)[0]
-                        yield response, "", prompt_token_len, completion_token_len
-                        break
+                            except HTTPException:
+                                pass
+                            response = response.split(stop)[0]
+                            yield response, "", prompt_token_len, completion_token_len
+                            break
+                    elif type(stop) == list:
+                        stop_exist_regex = "|".join(stop)
+                        matched = re.search(stop_exist_regex, response)
+                        if matched:
+                            try:
+                                state_cache.add_state(
+                                    state_cache.AddStateBody(
+                                        prompt=prompt + response,
+                                        tokens=self.model_tokens,
+                                        state=self.model_state,
+                                        logits=logits,
+                                    )
+                                )
+                            except HTTPException:
+                                pass
+                            response = response.split(matched.group())[0]
+                            yield response, "", prompt_token_len, completion_token_len
+                            break
                 out_last = begin + i + 1
         if i == self.max_tokens_per_generation - 1:
             try:
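
A caveat on the new list branch: "|".join(stop) hands the stop strings to re.search as a raw regex, so a stop sequence containing metacharacters such as (, ?, or | can mis-match or raise re.error. The sketch below shows a literal-safe variant under that assumption; find_first_stop is a hypothetical helper, and only the re.escape call differs from the committed logic.

import re
from typing import List, Optional

def find_first_stop(response: str, stop: List[str]) -> Optional[re.Match]:
    # Escape each stop sequence so it is matched literally, not as a regex pattern.
    stop_exist_regex = "|".join(re.escape(s) for s in stop)
    return re.search(stop_exist_regex, response)

text = "The answer is 42.\n\nQ(2): what comes next?"
matched = find_first_stop(text, ["\n\nQ(", "<|endoftext|>"])
if matched:
    # Same truncation the diff performs: cut the response at the first matched stop.
    print(text.split(matched.group())[0])  # The answer is 42.

Without escaping, "\n\nQ(" is an unterminated group and raises re.error, and "<|endoftext|>" degenerates into an alternation that also matches a bare < or >.
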