support for stop
array
This commit is contained in:
parent
05b9b42b56
commit
34095a6c36
@ -25,7 +25,7 @@ class ChatCompletionBody(ModelConfigBody):
|
|||||||
messages: List[Message]
|
messages: List[Message]
|
||||||
model: str = "rwkv"
|
model: str = "rwkv"
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
stop: str = None
|
stop: str | List[str] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -47,7 +47,7 @@ class CompletionBody(ModelConfigBody):
|
|||||||
prompt: Union[str, List[str]]
|
prompt: Union[str, List[str]]
|
||||||
model: str = "rwkv"
|
model: str = "rwkv"
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
stop: str = None
|
stop: str | List[str] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
|
@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import copy
|
import copy
|
||||||
from typing import Dict, List, Tuple
|
import re
|
||||||
|
from typing import Dict, Iterable, List, Tuple
|
||||||
from utils.log import quick_log
|
from utils.log import quick_log
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@ -276,21 +277,40 @@ class AbstractRWKV(ABC):
|
|||||||
if "\ufffd" not in delta: # avoid utf-8 display issues
|
if "\ufffd" not in delta: # avoid utf-8 display issues
|
||||||
response += delta
|
response += delta
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
if stop in response:
|
if type(stop) == str:
|
||||||
try:
|
if stop in response:
|
||||||
state_cache.add_state(
|
try:
|
||||||
state_cache.AddStateBody(
|
state_cache.add_state(
|
||||||
prompt=prompt + response,
|
state_cache.AddStateBody(
|
||||||
tokens=self.model_tokens,
|
prompt=prompt + response,
|
||||||
state=self.model_state,
|
tokens=self.model_tokens,
|
||||||
logits=logits,
|
state=self.model_state,
|
||||||
|
logits=logits,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
except HTTPException:
|
||||||
except HTTPException:
|
pass
|
||||||
pass
|
response = response.split(stop)[0]
|
||||||
response = response.split(stop)[0]
|
yield response, "", prompt_token_len, completion_token_len
|
||||||
yield response, "", prompt_token_len, completion_token_len
|
break
|
||||||
break
|
elif type(stop) == list:
|
||||||
|
stop_exist_regex = "|".join(stop)
|
||||||
|
matched = re.search(stop_exist_regex, response)
|
||||||
|
if matched:
|
||||||
|
try:
|
||||||
|
state_cache.add_state(
|
||||||
|
state_cache.AddStateBody(
|
||||||
|
prompt=prompt + response,
|
||||||
|
tokens=self.model_tokens,
|
||||||
|
state=self.model_state,
|
||||||
|
logits=logits,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
pass
|
||||||
|
response = response.split(matched.group())[0]
|
||||||
|
yield response, "", prompt_token_len, completion_token_len
|
||||||
|
break
|
||||||
out_last = begin + i + 1
|
out_last = begin + i + 1
|
||||||
if i == self.max_tokens_per_generation - 1:
|
if i == self.max_tokens_per_generation - 1:
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user