import os
import pathlib
import copy
from typing import Dict, List
from fastapi import HTTPException
from pydantic import BaseModel, Field
from rwkv_pip.utils import PIPELINE
from routes import state_cache

END_OF_TEXT = 0
END_OF_LINE = 187
END_OF_LINE_DOUBLE = 535

os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"


class RWKV:
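    """Wrapper around an RWKV model and its tokenizer pipeline that streams
    text generation and reuses cached prompt states via routes.state_cache."""
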
    def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
        from rwkv.model import RWKV as Model  # dynamic import to make RWKV_CUDA_ON work

        self.model = Model(model, strategy)
        self.pipeline = PIPELINE(self.model, tokens_path)
        self.model_state = None
        self.model_tokens = []

        self.CHUNK_LEN = 256

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.5
        self.penalty_alpha_presence = 0.4
        self.penalty_alpha_frequency = 0.4

        self.interface = ":"
        if "rwkv_vocab" in tokens_path:
            self.user = "Question"
            self.bot = "Answer"
        else:
            self.user = "Bob"
            self.bot = "Alice"
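
        # Tokens that must not be emitted twice in a row; run_rnn masks the
        # previous token's logit when it is one of these.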
        self.AVOID_REPEAT_TOKENS = []
        AVOID_REPEAT = ",:?!"
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            assert len(dd) == 1
            self.AVOID_REPEAT_TOKENS += dd

        self.preload()

    def preload(self):
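        """Run the built-in preset prompt through the model once and store the
        resulting tokens, state and logits in the state cache, so requests that
        start with this prompt can skip recomputing it."""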
        interface = self.interface
        user = self.user
        bot = self.bot
        preset_system = (
            f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
            if self.user == "Bob"
            else f"{user}{interface} hi\n\n{bot}{interface} Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
        )
        logits = self.run_rnn(self.pipeline.encode(preset_system))
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=preset_system,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            pass

    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
    def fix_tokens(self, tokens):
        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
            tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE]
        return tokens

    def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
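        """Feed tokens through the RNN in CHUNK_LEN-sized chunks, keep the model
        state, and return the next-token logits with the newline adjustment and
        the AVOID_REPEAT penalty applied."""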
        tokens = [int(x) for x in _tokens]
        self.model_tokens += tokens

        while len(tokens) > 0:
            out, self.model_state = self.model.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]

        out[END_OF_LINE] += newline_adj  # adjust \n probability
        if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out

    def generate(self, prompt: str, stop: str = None):
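        """Stream a completion for prompt as a generator of (response, delta)
        pairs, reusing the longest cached prefix state when available and
        stopping at END_OF_TEXT, the stop string, or the token limit."""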
        cache = None
        delta_prompt = prompt
        try:
            cache = state_cache.longest_prefix_state(
                state_cache.LongestPrefixStateBody(prompt=prompt)
            )
        except HTTPException:
            pass

        if cache is None or cache["prompt"] == "":
            self.model_state = None
            self.model_tokens = []
        else:
            delta_prompt = prompt[len(cache["prompt"]) :]
            self.model_state = copy.deepcopy(cache["state"])
            self.model_tokens = copy.deepcopy(cache["tokens"])
            logits = copy.deepcopy(cache["logits"])

        if delta_prompt != "":
            logits = self.run_rnn(self.fix_tokens(self.pipeline.encode(delta_prompt)))
            try:
                state_cache.add_state(
                    state_cache.AddStateBody(
                        prompt=prompt,
                        tokens=self.model_tokens,
                        state=self.model_state,
                        logits=logits,
                    )
                )
            except HTTPException:
                pass

        begin = len(self.model_tokens)
        out_last = begin
        occurrence: Dict = {}
        response = ""
        for i in range(self.max_tokens_per_generation):
            for n in occurrence:
                logits[n] -= (
                    self.penalty_alpha_presence
                    + occurrence[n] * self.penalty_alpha_frequency
                )
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p
            )
            if token == END_OF_TEXT:
                yield response, ""
                break
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            logits = self.run_rnn([token])
            delta: str = self.pipeline.decode(self.model_tokens[out_last:])
            if "\ufffd" not in delta:  # avoid utf-8 display issues
                response += delta
                if stop is not None:
                    if stop in response:
                        response = response.split(stop)[0]
                        try:
                            state_cache.add_state(
                                state_cache.AddStateBody(
                                    prompt=prompt + response,
                                    tokens=self.model_tokens,
                                    state=self.model_state,
                                    logits=logits,
                                )
                            )
                        except HTTPException:
                            pass
                        yield response, ""
                        break
                out_last = begin + i + 1
                if i == self.max_tokens_per_generation - 1:
                    try:
                        state_cache.add_state(
                            state_cache.AddStateBody(
                                prompt=prompt + response,
                                tokens=self.model_tokens,
                                state=self.model_state,
                                logits=logits,
                            )
                        )
                    except HTTPException:
                        pass
                yield response, delta


class ModelConfigBody(BaseModel):
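    """Partial generation settings; fields left as None are ignored by
    set_rwkv_config and keep the model's current values."""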
    max_tokens: int = Field(default=None, gt=0, le=102400)
    temperature: float = Field(default=None, ge=0, le=2)
    top_p: float = Field(default=None, ge=0, le=1)
    presence_penalty: float = Field(default=None, ge=-2, le=2)
    frequency_penalty: float = Field(default=None, ge=-2, le=2)


def set_rwkv_config(model: RWKV, body: ModelConfigBody):
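    """Copy the non-None fields of body onto the model's generation settings."""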
    if body.max_tokens is not None:
        model.max_tokens_per_generation = body.max_tokens
    if body.temperature is not None:
        model.temperature = body.temperature
    if body.top_p is not None:
        model.top_p = body.top_p
    if body.presence_penalty is not None:
        model.penalty_alpha_presence = body.presence_penalty
    if body.frequency_penalty is not None:
        model.penalty_alpha_frequency = body.frequency_penalty


def get_rwkv_config(model: RWKV) -> ModelConfigBody:
    return ModelConfigBody(
        max_tokens=model.max_tokens_per_generation,
        temperature=model.temperature,
        top_p=model.top_p,
        presence_penalty=model.penalty_alpha_presence,
        frequency_penalty=model.penalty_alpha_frequency,
    )
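

# Minimal usage sketch (illustrative only): the model path, strategy string, and
# tokenizer path below are placeholders, not files shipped with this module.
#
#   model = RWKV(
#       model="path/to/RWKV-4-model.pth",
#       strategy="cuda fp16",
#       tokens_path="20B_tokenizer.json",
#   )
#   set_rwkv_config(model, ModelConfigBody(max_tokens=200, temperature=1.0))
#   for response, delta in model.generate("Bob: hi\n\nAlice:", stop="\n\nBob"):
#       print(delta, end="", flush=True)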