support for stop array

josc146 2023-07-25 16:10:22 +08:00
parent 05b9b42b56
commit 34095a6c36
2 changed files with 37 additions and 17 deletions


@@ -25,7 +25,7 @@ class ChatCompletionBody(ModelConfigBody):
     messages: List[Message]
     model: str = "rwkv"
     stream: bool = False
-    stop: str = None
+    stop: str | List[str] = None
 
     class Config:
         schema_extra = {
@@ -47,7 +47,7 @@ class CompletionBody(ModelConfigBody):
     prompt: Union[str, List[str]]
     model: str = "rwkv"
     stream: bool = False
-    stop: str = None
+    stop: str | List[str] = None
 
     class Config:
         schema_extra = {
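
With this change both request bodies accept either a single stop string or a list of stop strings. A minimal sketch of how a client might send the list form; the endpoint path, port, and use of the requests library are assumptions for illustration, not part of this commit:

# Sketch only: exercising the new `stop` field with a list of strings.
# The URL below assumes a locally running OpenAI-compatible server.
import requests

payload = {
    "model": "rwkv",
    "prompt": "Q: What is RWKV?\nA:",
    "stream": False,
    "stop": ["\n\nQ:", "\nUser:"],  # previously only a single string was accepted
}
resp = requests.post("http://127.0.0.1:8000/v1/completions", json=payload)
print(resp.json())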


@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
 import os
 import pathlib
 import copy
-from typing import Dict, List, Tuple
+import re
+from typing import Dict, Iterable, List, Tuple
 from utils.log import quick_log
 from fastapi import HTTPException
 from pydantic import BaseModel, Field
@@ -276,6 +277,7 @@ class AbstractRWKV(ABC):
             if "\ufffd" not in delta:  # avoid utf-8 display issues
                 response += delta
                 if stop is not None:
+                    if type(stop) == str:
                         if stop in response:
                             try:
                                 state_cache.add_state(
@@ -291,6 +293,24 @@
                             response = response.split(stop)[0]
                             yield response, "", prompt_token_len, completion_token_len
                             break
+                    elif type(stop) == list:
+                        stop_exist_regex = "|".join(stop)
+                        matched = re.search(stop_exist_regex, response)
+                        if matched:
+                            try:
+                                state_cache.add_state(
+                                    state_cache.AddStateBody(
+                                        prompt=prompt + response,
+                                        tokens=self.model_tokens,
+                                        state=self.model_state,
+                                        logits=logits,
+                                    )
+                                )
+                            except HTTPException:
+                                pass
+                            response = response.split(matched.group())[0]
+                            yield response, "", prompt_token_len, completion_token_len
+                            break
                 out_last = begin + i + 1
             if i == self.max_tokens_per_generation - 1:
                 try:
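
The new elif branch treats the stop list as a regular-expression alternation ("|".join(stop)) and truncates the response at the first matching stop sequence. A small standalone sketch of that behavior, using made-up strings (not part of the commit):

# Illustration only: how the list branch trims generated output.
import re

response = "Paris is the capital of France.\n\nQ: And Germany?"
stop = ["\n\nQ:", "\n\nUser:"]

matched = re.search("|".join(stop), response)  # stop strings joined into one alternation
if matched:
    response = response.split(matched.group())[0]

print(repr(response))  # 'Paris is the capital of France.'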