support for stop
array
This commit is contained in:
parent
05b9b42b56
commit
34095a6c36
@ -25,7 +25,7 @@ class ChatCompletionBody(ModelConfigBody):
|
||||
messages: List[Message]
|
||||
model: str = "rwkv"
|
||||
stream: bool = False
|
||||
stop: str = None
|
||||
stop: str | List[str] = None
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
@ -47,7 +47,7 @@ class CompletionBody(ModelConfigBody):
|
||||
prompt: Union[str, List[str]]
|
||||
model: str = "rwkv"
|
||||
stream: bool = False
|
||||
stop: str = None
|
||||
stop: str | List[str] = None
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
|
@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
||||
import os
|
||||
import pathlib
|
||||
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
import re
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
from utils.log import quick_log
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
@ -276,21 +277,40 @@ class AbstractRWKV(ABC):
|
||||
if "\ufffd" not in delta: # avoid utf-8 display issues
|
||||
response += delta
|
||||
if stop is not None:
|
||||
if stop in response:
|
||||
try:
|
||||
state_cache.add_state(
|
||||
state_cache.AddStateBody(
|
||||
prompt=prompt + response,
|
||||
tokens=self.model_tokens,
|
||||
state=self.model_state,
|
||||
logits=logits,
|
||||
if type(stop) == str:
|
||||
if stop in response:
|
||||
try:
|
||||
state_cache.add_state(
|
||||
state_cache.AddStateBody(
|
||||
prompt=prompt + response,
|
||||
tokens=self.model_tokens,
|
||||
state=self.model_state,
|
||||
logits=logits,
|
||||
)
|
||||
)
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
response = response.split(stop)[0]
|
||||
yield response, "", prompt_token_len, completion_token_len
|
||||
break
|
||||
except HTTPException:
|
||||
pass
|
||||
response = response.split(stop)[0]
|
||||
yield response, "", prompt_token_len, completion_token_len
|
||||
break
|
||||
elif type(stop) == list:
|
||||
stop_exist_regex = "|".join(stop)
|
||||
matched = re.search(stop_exist_regex, response)
|
||||
if matched:
|
||||
try:
|
||||
state_cache.add_state(
|
||||
state_cache.AddStateBody(
|
||||
prompt=prompt + response,
|
||||
tokens=self.model_tokens,
|
||||
state=self.model_state,
|
||||
logits=logits,
|
||||
)
|
||||
)
|
||||
except HTTPException:
|
||||
pass
|
||||
response = response.split(matched.group())[0]
|
||||
yield response, "", prompt_token_len, completion_token_len
|
||||
break
|
||||
out_last = begin + i + 1
|
||||
if i == self.max_tokens_per_generation - 1:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user