RWKV-Runner/backend-python/utils/rwkv.py

726 lines
26 KiB
Python
Raw Normal View History

2023-07-25 16:09:31 +08:00
from abc import ABC, abstractmethod
2023-07-31 22:46:13 +08:00
from enum import Enum, auto
import os
import pathlib
import copy
2023-07-25 16:10:22 +08:00
import re
2024-02-02 22:00:01 +08:00
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
2023-06-12 13:41:51 +08:00
from utils.log import quick_log
from fastapi import HTTPException
2023-05-30 23:13:27 +08:00
from pydantic import BaseModel, Field
from routes import state_cache
2023-08-14 22:07:15 +08:00
import global_var
2023-05-28 12:53:14 +08:00
# Set TORCH_EXTENSIONS_DIR (the environment variable PyTorch's cpp_extension
# loader reads for its build directory) to the resolved directory two levels
# above this file.
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
2023-07-31 22:46:13 +08:00
class RWKVType(Enum):
    """Model family tag stored in ``AbstractRWKV.rwkv_type``."""

    # Default assigned by AbstractRWKV.__init__ before a subclass picks a family.
    NoneType = auto()
    # Chosen by TextRWKV when the vocabulary has fewer than 65536 entries.
    Raven = auto()
    # Chosen by TextRWKV for vocabularies of 65536 entries or more.
    World = auto()
    # Chosen by MusicMidiRWKV and MusicAbcRWKV.
    Music = auto()
2023-07-25 16:09:31 +08:00
class AbstractRWKV(ABC):
    """Common state, sampling settings and generation loop for RWKV wrappers.

    Subclasses supply the model-family specific pieces: repetition-penalty
    bookkeeping, token fix-ups, the forward pass and delta post-processing.
    """

    def __init__(self, model, pipeline):
        """Store the model/pipeline pair and initialise default sampling settings.

        Args:
            model: loaded RWKV model; must expose ``w["emb.weight"]``.
            pipeline: tokenizer/sampler object providing ``encode``, ``decode``
                and ``sample_logits``.
        """
        self.EOS_ID = 0
        self.name = "rwkv"
        self.model = model
        self.pipeline = pipeline
        self.model_state = None
        self.model_tokens = []
        self.rwkv_type: RWKVType = RWKVType.NoneType
        # Vocabulary size, taken from the embedding matrix row count.
        self.tokenizer_len = len(model.w["emb.weight"])

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.3
        self.top_k = 0
        self.penalty_alpha_presence = 0
        self.penalty_alpha_frequency = 1
        self.penalty_decay = 0.996

    @abstractmethod
    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Update the per-token ``occurrence`` table after ``token`` was sampled."""

    @abstractmethod
    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Modify ``logits`` in place before sampling generation step ``i``."""

    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
    @abstractmethod
    def fix_tokens(self, tokens) -> List[int]:
        """Return ``tokens`` with any tokenizer quirks corrected."""

    @abstractmethod
    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed ``_tokens`` through the model, updating ``model_state``/``model_tokens``.

        Returns:
            ``(logits, number_of_tokens_consumed)``.
        """

    @abstractmethod
    def delta_postprocess(self, delta: str) -> str:
        """Transform a freshly decoded text fragment before it is emitted."""

    def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
        """Return a normalised embedding vector for ``input`` and its token count.

        Args:
            input: text to embed.
            fast_mode: when true, only the first attention block is evaluated
                (see ``__fast_embedding``); otherwise the model state is reset,
                the full model is run and a slot of the final state is used.

        Returns:
            ``(embedding, token_len)`` where ``embedding`` is scaled to unit L2
            norm (left unscaled if its norm is zero).
        """
        import numpy as np

        if fast_mode:
            embedding, token_len = self.__fast_embedding(
                self.fix_tokens(self.pipeline.encode(input)), None
            )
        else:
            self.model_state = None
            self.model_tokens = []
            _, token_len = self.run_rnn(self.fix_tokens(self.pipeline.encode(input)))
            # NOTE(review): the meaning of slot -11 depends on the model's state
            # layout, which is not visible here -- confirm before changing.
            embedding = self.model_state[-11].tolist()

        vector = np.asarray(embedding)
        norm = np.linalg.norm(vector)
        if norm > 0:  # avoid producing NaNs for an all-zero vector
            vector = vector / norm
        return vector.tolist(), token_len

    def __fast_embedding(self, tokens: List[str], state):
        """Embed ``tokens`` by running only the first block's attention.

        The loop over layers returns after the first iteration, so only layer 0
        is evaluated and only its state slots are initialised.

        Args:
            tokens: token ids (converted with ``int``); must be non-empty.
            state: existing state list, or ``None`` to start from a fresh one.

        Returns:
            ``(state[0] as a list, number_of_tokens)``.
        """
        import torch

        tokens = [int(x) for x in tokens]
        token_len = len(tokens)
        model = self.model

        with torch.no_grad():
            w = model.w
            args = model.args

            if state is None:
                state = [None] * args.n_layer * 5
                for i in range(args.n_layer):
                    # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
                    dd = model.strategy[i]
                    dev = dd.device
                    atype = dd.atype
                    state[i * 5 + 0] = torch.zeros(
                        args.n_embd, dtype=atype, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 1] = torch.zeros(
                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 2] = torch.zeros(
                        args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                    ).contiguous()
                    state[i * 5 + 3] = (
                        torch.zeros(
                            args.n_embd,
                            dtype=torch.float,
                            requires_grad=False,
                            device=dev,
                        ).contiguous()
                        - 1e30
                    )
                    state[i * 5 + 4] = torch.zeros(
                        args.n_embd, dtype=atype, requires_grad=False, device=dev
                    ).contiguous()
                    # Only the first layer is evaluated below.
                    break

            seq_mode = len(tokens) > 1
            x = w["emb.weight"][tokens if seq_mode else tokens[0]]

            for i in range(args.n_layer):
                bbb = f"blocks.{i}."
                att = f"blocks.{i}.att."
                dd = model.strategy[i]
                dev = dd.device
                atype = dd.atype
                wtype = dd.wtype
                quantized = wtype == torch.uint8

                if seq_mode:
                    # A missing RWKV_CUDA_ON is treated as "off" instead of
                    # raising KeyError.
                    if "cuda" in str(dev) and os.environ.get("RWKV_CUDA_ON") == "1":
                        ATT = model.cuda_att_seq_i8 if quantized else model.cuda_att_seq
                    else:
                        ATT = model.att_seq_i8 if quantized else model.att_seq
                else:
                    ATT = model.att_one_i8 if quantized else model.att_one

                x = x.to(dtype=atype, device=dev)

                kw = w[f"{att}key.weight"]
                vw = w[f"{att}value.weight"]
                rw = w[f"{att}receptance.weight"]
                ow = w[f"{att}output.weight"]
                if dd.stream:
                    kw = kw.to(device=dev, non_blocking=True)
                    vw = vw.to(device=dev, non_blocking=True)
                    rw = rw.to(device=dev, non_blocking=True)
                    ow = ow.to(device=dev, non_blocking=True)

                # The *_mx/_rx/_my/_ry tensors are only looked up for uint8
                # weights; otherwise ``x`` is passed as a placeholder.
                kmx = w[f"{att}key.weight_mx"] if quantized else x
                krx = w[f"{att}key.weight_rx"] if quantized else x
                kmy = w[f"{att}key.weight_my"] if quantized else x
                kry = w[f"{att}key.weight_ry"] if quantized else x
                vmx = w[f"{att}value.weight_mx"] if quantized else x
                vrx = w[f"{att}value.weight_rx"] if quantized else x
                vmy = w[f"{att}value.weight_my"] if quantized else x
                vry = w[f"{att}value.weight_ry"] if quantized else x
                rmx = w[f"{att}receptance.weight_mx"] if quantized else x
                rrx = w[f"{att}receptance.weight_rx"] if quantized else x
                rmy = w[f"{att}receptance.weight_my"] if quantized else x
                rry = w[f"{att}receptance.weight_ry"] if quantized else x
                omx = w[f"{att}output.weight_mx"] if quantized else x
                orx = w[f"{att}output.weight_rx"] if quantized else x
                omy = w[f"{att}output.weight_my"] if quantized else x
                ory = w[f"{att}output.weight_ry"] if quantized else x

                (
                    x,
                    state[i * 5 + 0],
                    state[i * 5 + 1],
                    state[i * 5 + 2],
                    state[i * 5 + 3],
                ) = ATT(
                    x,
                    state[i * 5 + 0],
                    state[i * 5 + 1],
                    state[i * 5 + 2],
                    state[i * 5 + 3],
                    w[f"{bbb}ln1.weight"],
                    w[f"{bbb}ln1.bias"],
                    w[f"{att}time_mix_k"],
                    w[f"{att}time_mix_v"],
                    w[f"{att}time_mix_r"],
                    w[f"{att}time_decay"],
                    w[f"{att}time_first"],
                    kw,
                    vw,
                    rw,
                    ow,
                    kmx,
                    krx,
                    kmy,
                    kry,
                    vmx,
                    vrx,
                    vmy,
                    vry,
                    rmx,
                    rrx,
                    rmy,
                    rry,
                    omx,
                    orx,
                    omy,
                    ory,
                )

                # Return right after the first block's attention.
                return state[0].tolist(), token_len

    @staticmethod
    def _first_stop_index(
        text: str, stop: Union[str, List[str], None]
    ) -> Union[int, None]:
        """Return the index of the earliest stop sequence in ``text``, or ``None``.

        Stop sequences are matched literally. Empty stop strings are ignored,
        and a ``stop`` that is neither a string nor a list/tuple is ignored.
        """
        if isinstance(stop, str):
            candidates = [stop]
        elif isinstance(stop, (list, tuple)):
            candidates = stop
        else:
            return None
        positions = [
            pos for pos in (text.find(item) for item in candidates if item) if pos >= 0
        ]
        return min(positions) if positions else None

    def _cache_state(self, prompt: str, logits) -> None:
        """Best-effort store of the current tokens/state/logits under ``prompt``.

        An ``HTTPException`` from the state cache is deliberately ignored so
        that a cache failure never interrupts generation.
        """
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=prompt,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            pass

    def generate(
        self, prompt: str, stop: Union[str, List[str], None] = None
    ) -> Iterable[Tuple[str, str, int, int]]:
        """Stream a completion for ``prompt``.

        The longest cached prefix of ``prompt`` is reused when available; only
        the remaining text is fed through the model.

        Args:
            prompt: full prompt text.
            stop: a stop string or a list of stop strings, matched literally;
                generation ends at the earliest one and the response is cut
                just before it.

        Yields:
            ``(response_so_far, delta, prompt_token_len, completion_token_len)``.
            The final item emitted on EOS or on a stop match has ``delta == ""``.
        """
        quick_log(None, None, "Generation Prompt:\n" + prompt)
        cache = None
        delta_prompt = prompt
        try:
            cache = state_cache.longest_prefix_state(
                state_cache.LongestPrefixStateBody(prompt=prompt), None
            )
        except HTTPException:
            pass
        if cache is None or cache["prompt"] == "" or cache["state"] is None:
            self.model_state = None
            self.model_tokens = []
        else:
            delta_prompt = prompt[len(cache["prompt"]) :]
            self.model_state = cache["state"]
            self.model_tokens = cache["tokens"]
            logits = cache["logits"]

        # NOTE(review): with an empty prompt and no usable cache entry,
        # ``logits`` is never assigned and the loop below fails on first use.
        prompt_token_len = 0
        if delta_prompt != "":
            logits, prompt_token_len = self.run_rnn(
                self.fix_tokens(self.pipeline.encode(delta_prompt))
            )
            self._cache_state(prompt, logits)

        begin = len(self.model_tokens)
        out_last = begin

        occurrence: Dict = {}

        completion_token_len = 0
        response = ""
        for i in range(self.max_tokens_per_generation):
            self.adjust_forward_logits(logits, occurrence, i)

            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
            )

            if token == self.EOS_ID:
                self._cache_state(prompt + response, logits)
                yield response, "", prompt_token_len, completion_token_len
                break

            self.adjust_occurrence(occurrence, token)

            logits, _ = self.run_rnn([token])
            completion_token_len = completion_token_len + 1
            delta: str = self.delta_postprocess(
                self.pipeline.decode(self.model_tokens[out_last:])
            )
            # Hold back output while the pending tokens decode to an incomplete
            # character (avoids utf-8 display issues).
            if "\ufffd" not in delta:
                response += delta

                stop_index = self._first_stop_index(response, stop)
                if stop_index is not None:
                    # Cache under the untrimmed text: the cached state has
                    # already consumed the stop sequence's tokens.
                    self._cache_state(prompt + response, logits)
                    response = response[:stop_index]
                    yield response, "", prompt_token_len, completion_token_len
                    break

                out_last = begin + i + 1
                if i == self.max_tokens_per_generation - 1:
                    self._cache_state(prompt + response, logits)
                yield response, delta, prompt_token_len, completion_token_len
2023-05-17 11:39:00 +08:00
2023-07-25 16:09:31 +08:00
class TextRWKV(AbstractRWKV):
    """Wrapper for text models (Raven- and World-style vocabularies).

    The family is selected from the vocabulary size: fewer than 65536 entries
    selects ``RWKVType.Raven``, anything else ``RWKVType.World``.
    """

    def __init__(self, model, pipeline) -> None:
        """Configure text defaults and preload the preset system prompt."""
        super().__init__(model, pipeline)

        # Maximum number of tokens passed to one ``model.forward`` call.
        self.CHUNK_LEN = 256

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.3
        self.top_k = 0
        self.penalty_alpha_presence = 0
        self.penalty_alpha_frequency = 1

        self.interface = ":"
        if self.tokenizer_len < 65536:
            self.rwkv_type = RWKVType.Raven
            self.user = "Bob"
            self.bot = "Alice"
            self.END_OF_LINE = 187
        else:
            self.rwkv_type = RWKVType.World
            self.user = "User"
            self.bot = "Assistant"
            self.END_OF_LINE = 11

        # Tokens that must not be emitted twice in a row. Each character of
        # AVOID_REPEAT must encode to exactly one token; it is currently empty,
        # so the set stays empty.
        self.AVOID_REPEAT_TOKENS = set()
        AVOID_REPEAT = ""
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            assert len(dd) == 1
            self.AVOID_REPEAT_TOKENS.add(dd[0])

        self.__preload()

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Decay every stored count by ``penalty_decay``, then count ``token`` once."""
        for xxx in occurrence:
            occurrence[xxx] *= self.penalty_decay
        if token not in occurrence:
            occurrence[token] = 1
        else:
            occurrence[token] += 1

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Apply presence/frequency penalties to ``logits`` in place.

        On the first step (``i == 0``) the table is then primed with every
        token already in ``model_tokens``, so those tokens are penalised from
        the second step onwards.
        """
        for n in occurrence:
            logits[n] -= (
                self.penalty_alpha_presence
                + occurrence[n] * self.penalty_alpha_frequency
            )

        if i == 0:
            for token in self.model_tokens:
                self.adjust_occurrence(occurrence, int(token))

    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
    def fix_tokens(self, tokens) -> List[int]:
        """Replace a trailing token 535 with two END_OF_LINE tokens (non-World models only)."""
        if self.rwkv_type == RWKVType.World:
            return tokens
        if len(tokens) > 0 and tokens[-1] == 535:
            tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed tokens through the model in chunks of ``CHUNK_LEN``.

        Args:
            _tokens: token ids (converted with ``int``); must be non-empty.
            newline_adj: value added to the END_OF_LINE logit of the result.

        Returns:
            ``(logits_after_last_chunk, number_of_tokens)``.

        Raises:
            ValueError: if ``_tokens`` is empty, since no logits can be produced.
        """
        tokens = [int(x) for x in _tokens]
        if not tokens:
            raise ValueError("run_rnn requires at least one token")
        token_len = len(tokens)
        self.model_tokens += tokens

        for start in range(0, token_len, self.CHUNK_LEN):
            out, self.model_state = self.model.forward(
                tokens[start : start + self.CHUNK_LEN], self.model_state
            )

        out[self.END_OF_LINE] += newline_adj  # adjust \n probability

        if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out, token_len

    def delta_postprocess(self, delta: str) -> str:
        """Return ``delta`` unchanged."""
        return delta

    def __preload(self):
        """Run the preset system prompt through the model and cache the result.

        An ``HTTPException`` from the state cache is ignored (best effort).
        """
        interface = self.interface
        user = self.user
        bot = self.bot
        preset_system = (
            f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
            if self.rwkv_type == RWKVType.Raven
            else (
                f"{user}{interface} hi\n\n{bot}{interface} Hi. "
                + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
            )
        )
        logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=preset_system,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            pass
2024-01-05 12:44:44 +08:00
class MusicMidiRWKV(AbstractRWKV):
    """Wrapper for MIDI music models, with its own repetition-penalty rules."""

    def __init__(self, model, pipeline):
        """Apply the MIDI sampling defaults on top of the base settings."""
        super().__init__(model, pipeline)

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.8
        self.top_k = 8

        self.rwkv_type = RWKVType.Music

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Decay all counts, then add a weight for ``token``.

        Tokens ``>= 128`` and token ``127`` add 1; every other token adds 0.3.
        """
        for seen in occurrence:
            occurrence[seen] *= 0.997  #### decay repetition penalty
        previous = occurrence[token] if token in occurrence else 0
        if token >= 128 or token == 127:
            occurrence[token] = 1 + previous
        else:
            occurrence[token] = 0.3 + previous

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Penalise repeated tokens and bias token 0 and token 127 in place."""
        for seen, weight in occurrence.items():
            logits[seen] -= 0 + weight * 0.5

        logits[0] += (i - 2000) / 500  # try not to be too short or too long
        logits[127] -= 1  # avoid "t125"

    def fix_tokens(self, tokens) -> List[int]:
        """Return ``tokens`` unchanged."""
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Run all tokens through the model in one forward call.

        ``newline_adj`` is accepted for interface compatibility and not used.
        """
        ids = [int(raw) for raw in _tokens]
        self.model_tokens += ids
        logits, self.model_state = self.model.forward(ids, self.model_state)
        return logits, len(ids)

    def delta_postprocess(self, delta: str) -> str:
        """Prefix ``delta`` with a single space."""
        return " " + delta
2024-01-05 12:44:44 +08:00
class MusicAbcRWKV(AbstractRWKV):
    """Wrapper for ABC-notation music models; applies no repetition penalty."""

    def __init__(self, model, pipeline):
        """Apply the ABC sampling defaults and use token 3 as end-of-sequence."""
        super().__init__(model, pipeline)

        self.EOS_ID = 3

        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.8
        self.top_k = 8

        self.rwkv_type = RWKVType.Music

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """No-op: occurrences are not tracked for this model family."""
        return None

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """No-op: logits are sampled unmodified."""
        return None

    def fix_tokens(self, tokens) -> List[int]:
        """Return ``tokens`` unchanged."""
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Run all tokens through the model in one forward call.

        ``newline_adj`` is accepted for interface compatibility and not used.
        """
        ids = [int(raw) for raw in _tokens]
        self.model_tokens += ids
        logits, self.model_state = self.model.forward(ids, self.model_state)
        return logits, len(ids)

    def delta_postprocess(self, delta: str) -> str:
        """Return ``delta`` unchanged."""
        return delta
def get_tokenizer(tokenizer_len: int):
    """Pick a tokenizer identifier from the vocabulary size.

    Returns either a bare tokenizer name or a path to a JSON file inside the
    ``rwkv_pip`` directory located two levels above this file.
    """
    base_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
    # (exclusive upper bound on vocabulary size, tokenizer), in ascending order.
    tiers = (
        (2176, "abc_tokenizer"),
        (20096, base_dir + "tokenizer-midipiano.json"),
        (50277, base_dir + "tokenizer-midi.json"),
        (65536, base_dir + "20B_tokenizer.json"),
    )
    for upper_bound, choice in tiers:
        if tokenizer_len < upper_bound:
            return choice
    return "rwkv_vocab_v20230424"
2024-02-02 22:00:01 +08:00
def get_model_path(model_path: str) -> str:
    """Resolve a model path against the known model locations.

    Absolute paths are returned untouched. A relative path is tried under each
    search root, first directly and then under ``build/bin`` (dev layout); the
    first existing file wins. If nothing matches, ``model_path`` is returned
    unchanged.
    """
    if os.path.isabs(model_path):
        return model_path

    cwd = pathlib.Path(os.path.abspath(os.getcwd()))
    this_file = pathlib.Path(os.path.abspath(__file__))
    search_roots = (
        cwd,  # [cwd](RWKV-Runner)/models/xxx
        cwd.parent,  # [cwd](backend-python)/../models/xxx
        this_file.parent.parent,  # backend-python/models/xxx
        this_file.parent.parent.parent,  # RWKV-Runner/models/xxx
    )
    candidates = (
        candidate
        for root in search_roots
        for candidate in (
            root / model_path,
            root / "build" / "bin" / model_path,  # for dev
        )
    )
    return next(
        (str(candidate) for candidate in candidates if os.path.isfile(candidate)),
        model_path,
    )
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
    """Load a model file and wrap it in the matching ``AbstractRWKV`` subclass.

    Args:
        model: model file path, resolved through ``get_model_path``.
        strategy: strategy string passed to the backend model constructor.
        tokenizer: tokenizer name or path; when falsy it is chosen from the
            model's vocabulary size via ``get_tokenizer``.

    Returns:
        The wrapper instance, with ``name`` set to the model file's base name
        without extension.
    """
    model_path = get_model_path(model)

    rwkv_beta = global_var.get(global_var.Args).rwkv_beta
    rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
    webgpu = global_var.get(global_var.Args).webgpu

    lowered_path = model_path.lower()
    if any(tag in lowered_path for tag in ("midi", "abc")):
        # NOTE(review): presumably consumed by the rwkv_pip model code at
        # import/load time -- it must be set before the import below.
        os.environ["RWKV_RESCALE_LAYER"] = "999"

    # dynamic import to make RWKV_CUDA_ON work
    if rwkv_beta:
        print("Using rwkv-beta")
        from rwkv_pip.beta.model import RWKV as Model
    elif rwkv_cpp:
        print("Using rwkv.cpp, strategy is ignored")
        from rwkv_pip.cpp.model import RWKV as Model
    elif webgpu:
        print("Using webgpu")
        from rwkv_pip.webgpu.model import RWKV as Model
    else:
        from rwkv_pip.model import RWKV as Model

    from rwkv_pip.utils import PIPELINE

    filename = os.path.splitext(os.path.basename(model_path))[0]
    net = Model(model_path, strategy)
    tokenizer = tokenizer or get_tokenizer(len(net.w["emb.weight"]))
    pipeline = PIPELINE(net, tokenizer)

    # Exact tokenizer base names with a known wrapper class.
    rwkv_map: dict[str, Type[AbstractRWKV]] = {
        "20B_tokenizer": TextRWKV,
        "rwkv_vocab_v20230424": TextRWKV,
        "tokenizer-midi": MusicMidiRWKV,
        "tokenizer-midipiano": MusicMidiRWKV,
        "abc_tokenizer": MusicAbcRWKV,
    }
    tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]

    if tokenizer_name == "tokenizer-midipiano":
        midi_vocab = global_var.MidiVocabConfig.Piano
    else:
        midi_vocab = global_var.MidiVocabConfig.Default
    global_var.set(global_var.Midi_Vocab_Config_Type, midi_vocab)

    wrapper_cls = rwkv_map.get(tokenizer_name)
    if wrapper_cls is None:
        # Unknown tokenizer: fall back to substring matching on its name.
        lowered_name = tokenizer_name.lower()
        if "music" in lowered_name or "midi" in lowered_name:
            wrapper_cls = MusicMidiRWKV
        elif "abc" in lowered_name:
            wrapper_cls = MusicAbcRWKV
        else:
            wrapper_cls = TextRWKV

    rwkv: AbstractRWKV = wrapper_cls(net, pipeline)
    rwkv.name = filename
    return rwkv
2023-05-17 11:39:00 +08:00
class ModelConfigBody(BaseModel):
    """Sampling configuration payload.

    Every field is optional: ``None`` means "not provided", which
    ``set_rwkv_config`` treats as "leave the current value unchanged". The
    annotations therefore allow ``None`` explicitly, so a request carrying an
    explicit ``null`` is accepted rather than rejected by validation.
    """

    max_tokens: Union[int, None] = Field(default=None, gt=0, le=102400)
    temperature: Union[float, None] = Field(default=None, ge=0, le=2)
    top_p: Union[float, None] = Field(default=None, ge=0, le=1)
    presence_penalty: Union[float, None] = Field(default=None, ge=-2, le=2)
    frequency_penalty: Union[float, None] = Field(default=None, ge=-2, le=2)
    penalty_decay: Union[float, None] = Field(default=None, ge=0.99, le=0.999)
    top_k: Union[int, None] = Field(default=None, ge=0, le=25)

    model_config = {
        "json_schema_extra": {
            "example": {
                "max_tokens": 1000,
                "temperature": 1,
                "top_p": 0.3,
                "presence_penalty": 0,
                "frequency_penalty": 1,
                "penalty_decay": 0.996,
            }
        }
    }
2023-06-15 21:52:22 +08:00
2023-05-17 11:39:00 +08:00
2023-07-25 16:09:31 +08:00
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
    """Copy every non-``None`` field of ``body`` onto ``model``.

    ``temperature`` is raised to at least 0.1; all other values are copied
    as given.
    """
    if body.max_tokens is not None:
        model.max_tokens_per_generation = body.max_tokens

    if body.temperature is not None:
        model.temperature = max(body.temperature, 0.1)

    # (field on ``body``, attribute on ``model``), copied verbatim when set.
    passthrough = (
        ("top_p", "top_p"),
        ("presence_penalty", "penalty_alpha_presence"),
        ("frequency_penalty", "penalty_alpha_frequency"),
        ("penalty_decay", "penalty_decay"),
        ("top_k", "top_k"),
    )
    for field_name, attr_name in passthrough:
        value = getattr(body, field_name)
        if value is not None:
            setattr(model, attr_name, value)
2023-05-17 11:39:00 +08:00
2023-07-25 16:09:31 +08:00
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
    """Snapshot the model's current sampling settings as a ``ModelConfigBody``."""
    current = {
        "max_tokens": model.max_tokens_per_generation,
        "temperature": model.temperature,
        "top_p": model.top_p,
        "presence_penalty": model.penalty_alpha_presence,
        "frequency_penalty": model.penalty_alpha_frequency,
        "penalty_decay": model.penalty_decay,
        "top_k": model.top_k,
    }
    return ModelConfigBody(**current)