# RWKV-Runner/backend-python/utils/rwkv.py
import copy
import os
import pathlib
import re
import time
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type, Union

from fastapi import HTTPException
from pydantic import BaseModel, Field

from utils.log import quick_log
from routes import state_cache
import global_var

# Torch JIT extensions are built/cached two directories above this file
# (backend-python/ when this module lives in backend-python/utils/).
_EXTENSIONS_ROOT = pathlib.Path(__file__).parent.parent.resolve()
os.environ["TORCH_EXTENSIONS_DIR"] = str(_EXTENSIONS_ROOT)

class RWKVType(Enum):
    """Family of a loaded RWKV model.

    NoneType is the initial value set by AbstractRWKV; subclasses replace it.
    Values are 1..4, matching what ``auto()`` assigns.
    """

    NoneType = 1
    Raven = 2
    World = 3
    Music = 4


class AbstractRWKV(ABC):
    """Shared prompt-prefill, sampling loop and embedding logic for an RWKV model.

    Subclasses supply the tokenizer-specific hooks: repetition-penalty
    bookkeeping, token fix-ups, the forward pass and delta post-processing.
    """

    def __init__(self, model, pipeline):
        self.EOS_ID = 0  # sampling this token ends generate()
        self.name = "rwkv"  # RWKV() overwrites this with the model file name
        self.version = 4  # RWKV() overwrites this with model.version
        self.model = model
        self.pipeline = pipeline  # provides encode / decode / sample_logits
        self.model_state = None  # recurrent state; None means "start from scratch"
        self.model_tokens = []  # all tokens fed into model_state so far
        self.rwkv_type: RWKVType = RWKVType.NoneType
        # Row count of the embedding matrix; used to pick the tokenizer / model type.
        self.tokenizer_len = len(model.w["emb.weight"])

        # Default sampling settings (exposed via set_rwkv_config / get_rwkv_config).
        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.3
        self.top_k = 0
        self.penalty_alpha_presence = 0
        self.penalty_alpha_frequency = 1
        self.penalty_decay = 0.996
        # When True, prompt tokens also count towards the repetition penalty.
        self.global_penalty = False

    @abstractmethod
    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Record the freshly sampled ``token`` in the penalty table ``occurrence``."""
        pass

    @abstractmethod
    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Apply penalties to ``logits`` in place before sampling step ``i``.

        The return value is ignored by generate(), so changes must be in place.
        """
        pass

    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
    @abstractmethod
    def fix_tokens(self, tokens) -> List[int]:
        """Return ``tokens`` adjusted to match what the model saw in training."""
        pass

    @abstractmethod
    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed tokens into the model, updating model_state and model_tokens.

        Returns ``(logits_after_last_token, number_of_tokens_fed)``.
        """
        pass

    @abstractmethod
    def delta_postprocess(self, delta: str) -> str:
        """Transform a newly decoded text fragment before it is appended to the response."""
        pass

    def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
        """Return ``(embedding, token_count)`` for ``input``.

        fast_mode=True: only layer 0's attention is evaluated (see
        __fast_embedding) and the vector is returned un-normalised.
        fast_mode=False: the full model is run from a fresh state (this resets
        model_state / model_tokens) and ``model_state[-11]`` is returned,
        L2-normalised.
        """
        import numpy as np

        tokens = self.fix_tokens(self.pipeline.encode(input))
        if fast_mode:
            return self.__fast_embedding(tokens, None)

        self.model_state = None
        self.model_tokens = []
        _, token_len = self.run_rnn(tokens)
        # NOTE(review): index -11 selects one tensor of the backend's flattened
        # state list; which layer/slot that is depends on the backend layout — confirm.
        embedding = self.model_state[-11].tolist()
        embedding = (embedding / np.linalg.norm(embedding)).tolist()
        return embedding, token_len

    def __fast_embedding(self, tokens: List[str], state):
        """Cheap embedding: run layer 0's attention only and return its att_xx state.

        Args:
            tokens: token ids (converted with int()); must be non-empty.
            state: a flattened state list (5 slots per layer) or None for a
                fresh layer-0 state.

        Returns:
            ``(state[0].tolist(), number_of_tokens)``.

        Raises:
            ValueError: if ``tokens`` is empty.
        """
        import torch

        tokens = [int(x) for x in tokens]
        token_len = len(tokens)
        if token_len == 0:
            raise ValueError("cannot embed an empty token sequence")
        model = self.model

        with torch.no_grad():
            w = model.w
            args = model.args
            # Only layer 0 is evaluated: layer i writes state[i*5 .. i*5+4], so
            # the returned state[0] is never touched by later layers. (Running
            # further layers also failed: their state slots were never initialised.)
            layer = 0
            dd = model.strategy[layer]
            dev = dd.device
            atype = dd.atype
            wtype = dd.wtype
            quantized = wtype == torch.uint8

            if state is None:
                # Per-layer layout: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx.
                # Only layer 0's five slots are filled; the rest stay None.
                def zeros(dtype):
                    return torch.zeros(
                        args.n_embd, dtype=dtype, requires_grad=False, device=dev
                    ).contiguous()

                state = [None] * args.n_layer * 5
                state[0] = zeros(atype)
                state[1] = zeros(torch.float)
                state[2] = zeros(torch.float)
                state[3] = zeros(torch.float) - 1e30
                state[4] = zeros(atype)

            seq_mode = len(tokens) > 1
            x = w["emb.weight"][tokens if seq_mode else tokens[0]]

            bbb = f"blocks.{layer}."
            att = f"blocks.{layer}.att."
            if seq_mode:
                if "cuda" in str(dev) and os.environ.get("RWKV_CUDA_ON") == "1":
                    ATT = model.cuda_att_seq if not quantized else model.cuda_att_seq_i8
                else:
                    ATT = model.att_seq if not quantized else model.att_seq_i8
            else:
                ATT = model.att_one if not quantized else model.att_one_i8

            x = x.to(dtype=atype, device=dev)
            kw = w[f"{att}key.weight"]
            vw = w[f"{att}value.weight"]
            rw = w[f"{att}receptance.weight"]
            ow = w[f"{att}output.weight"]
            if dd.stream:
                kw = kw.to(device=dev, non_blocking=True)
                vw = vw.to(device=dev, non_blocking=True)
                rw = rw.to(device=dev, non_blocking=True)
                ow = ow.to(device=dev, non_blocking=True)
            # int8 weights carry mx/rx/my/ry tensors; for other dtypes x is passed
            # to fill these argument slots (presumably ignored by those kernels).
            quant_params = [
                w[f"{att}{proj}.weight_{part}"] if quantized else x
                for proj in ("key", "value", "receptance", "output")
                for part in ("mx", "rx", "my", "ry")
            ]
            (x, state[0], state[1], state[2], state[3]) = ATT(
                x,
                state[0],
                state[1],
                state[2],
                state[3],
                w[f"{bbb}ln1.weight"],
                w[f"{bbb}ln1.bias"],
                w[f"{att}time_mix_k"],
                w[f"{att}time_mix_v"],
                w[f"{att}time_mix_r"],
                w[f"{att}time_decay"],
                w[f"{att}time_first"],
                kw,
                vw,
                rw,
                ow,
                *quant_params,
            )
            return state[0].tolist(), token_len

    def _add_state_to_cache(self, prompt: str, logits) -> None:
        """Store the current state/tokens/logits under ``prompt`` in state_cache.

        Best-effort: an HTTPException from the cache is swallowed so generation
        continues without caching.
        NOTE(review): model_tokens and logits are mutated after this call;
        assumes state_cache.add_state copies them — confirm.
        """
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=prompt,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            pass

    def generate(
        self, prompt: str, stop: Union[str, List[str], None] = None
    ) -> Iterable[Tuple[str, str, int, int]]:
        """Stream a completion of ``prompt``.

        Reuses the longest cached prefix state when available, prefills the
        remaining prompt, then samples up to max_tokens_per_generation tokens.

        Yields:
            ``(response_so_far, delta, prompt_token_len, completion_token_len)``.
            The final item has ``delta == ""`` when generation ends on EOS or a
            stop string (the response is then cut before the stop string).

        Raises:
            ValueError: if there is nothing to condition on (empty prompt and no
                usable cached state).
        """
        quick_log(None, None, "Generation Prompt:\n" + prompt)
        cache = None
        delta_prompt = prompt
        try:
            cache = state_cache.longest_prefix_state(
                state_cache.LongestPrefixStateBody(prompt=prompt), None
            )
        except HTTPException:
            pass

        logits = None
        if cache is None or cache["prompt"] == "" or cache["state"] is None:
            self.model_state = None
            self.model_tokens = []
        else:
            delta_prompt = prompt[len(cache["prompt"]) :]
            self.model_state = cache["state"]
            # Copies: model_tokens is extended in place by run_rnn and logits is
            # modified in place by adjust_forward_logits; never touch cached objects.
            self.model_tokens = list(cache["tokens"])
            logits = copy.deepcopy(cache["logits"])

        prompt_token_len = 0
        if delta_prompt != "":
            prompt_start_time = time.perf_counter()
            logits, prompt_token_len = self.run_rnn(
                self.fix_tokens(self.pipeline.encode(delta_prompt))
            )
            elapsed = time.perf_counter() - prompt_start_time
            # Guard: a coarse clock can report zero elapsed time.
            if elapsed > 0:
                tps = prompt_token_len / elapsed
                print(f"Prompt Prefill TPS: {tps:.2f}", end=" ", flush=True)
            self._add_state_to_cache(prompt, logits)

        if logits is None:
            raise ValueError("prompt is empty and no cached state is available")

        # Normalise `stop` to a list; empty strings would match immediately and
        # make str.split raise, so they are ignored.
        if stop is None:
            stops: List[str] = []
        elif isinstance(stop, str):
            stops = [stop]
        else:
            stops = list(stop)
        stops = [s for s in stops if s]

        begin = len(self.model_tokens)
        out_last = begin
        occurrence: Dict = {}
        completion_token_len = 0
        response = ""
        for i in range(self.max_tokens_per_generation):
            self.adjust_forward_logits(logits, occurrence, i)
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
            )

            if token == self.EOS_ID:
                self._add_state_to_cache(prompt + response, logits)
                yield response, "", prompt_token_len, completion_token_len
                break

            self.adjust_occurrence(occurrence, token)

            logits, _ = self.run_rnn([token])
            completion_token_len = completion_token_len + 1
            delta: str = self.delta_postprocess(
                self.pipeline.decode(self.model_tokens[out_last:])
            )
            # An incomplete UTF-8 sequence decodes to U+FFFD; wait for more
            # tokens (out_last is not advanced) to avoid display issues.
            if "\ufffd" in delta:
                continue

            response += delta
            hit = next((s for s in stops if s in response), None)
            if hit is not None:
                self._add_state_to_cache(prompt + response, logits)
                response = response.split(hit)[0]
                yield response, "", prompt_token_len, completion_token_len
                break

            out_last = begin + i + 1
            if i == self.max_tokens_per_generation - 1:
                self._add_state_to_cache(prompt + response, logits)
            yield response, delta, prompt_token_len, completion_token_len


class TextRWKV(AbstractRWKV):
    """RWKV text model (Raven with the 20B tokenizer, or World with the RWKV vocab)."""

    def __init__(self, model, pipeline) -> None:
        super().__init__(model, pipeline)
        # Sampling defaults are the ones set by AbstractRWKV.__init__.
        self.CHUNK_LEN = 256  # max tokens per model.forward call in run_rnn
        self.interface = ":"
        if self.tokenizer_len < 65536:
            self.rwkv_type = RWKVType.Raven
            self.user = "Bob"
            self.bot = "Alice"
            self.END_OF_LINE = 187
        else:
            self.rwkv_type = RWKVType.World
            self.user = "User"
            self.bot = "Assistant"
            self.END_OF_LINE = 11

        # Tokens whose logit is suppressed right after they were emitted (run_rnn).
        self.AVOID_REPEAT_TOKENS = set()
        AVOID_REPEAT = ""
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            if len(dd) != 1:
                raise ValueError(
                    f"AVOID_REPEAT character {i!r} must encode to exactly one token"
                )
            self.AVOID_REPEAT_TOKENS.add(dd[0])

        # Single-token punctuation/digits/space excluded from the prompt-based
        # (global) penalty in adjust_forward_logits.
        self.AVOID_PENALTY_TOKENS = set()
        AVOID_PENALTY = '\n,.:?!,。:?!"“”<>[]{}/\\|;~`@#$%^&*()_+-=0123456789 '
        for i in AVOID_PENALTY:
            dd = self.pipeline.encode(i)
            if len(dd) == 1:
                self.AVOID_PENALTY_TOKENS.add(dd[0])

        self.__preload()

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Decay all counts by penalty_decay, then count ``token`` once more."""
        for xxx in occurrence:
            occurrence[xxx] *= self.penalty_decay
        occurrence[token] = occurrence.get(token, 0) + 1

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Subtract presence + frequency penalties in place.

        With global_penalty, prompt tokens (except AVOID_PENALTY_TOKENS) are
        added to ``occurrence`` on step 0, so they affect later steps.
        """
        for n in occurrence:
            logits[n] -= (
                self.penalty_alpha_presence
                + occurrence[n] * self.penalty_alpha_frequency
            )

        # set global_penalty to False to get the same generated results as the official RWKV Gradio
        if self.global_penalty and i == 0:
            for token in self.model_tokens:
                token = int(token)
                if token not in self.AVOID_PENALTY_TOKENS:
                    self.adjust_occurrence(occurrence, token)

    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
    def fix_tokens(self, tokens) -> List[int]:
        """Raven only: replace a trailing 535 with two END_OF_LINE tokens."""
        if self.rwkv_type == RWKVType.World:
            return tokens
        if len(tokens) > 0 and tokens[-1] == 535:
            tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed tokens in CHUNK_LEN chunks; return (last logits, token count).

        ``newline_adj`` is added to the END_OF_LINE logit.

        Raises:
            ValueError: if no tokens are given (there would be no logits).
        """
        tokens = [int(x) for x in _tokens]
        if not tokens:
            raise ValueError("run_rnn requires at least one token")
        token_len = len(tokens)
        self.model_tokens += tokens
        while len(tokens) > 0:
            out, self.model_state = self.model.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]
        out[self.END_OF_LINE] += newline_adj  # adjust \n probability
        if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out, token_len

    def delta_postprocess(self, delta: str) -> str:
        return delta

    def __preload(self):
        """Run the preset system prompt and store the resulting state in state_cache."""
        interface = self.interface
        user = self.user
        bot = self.bot
        preset_system = (
            f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
            if self.rwkv_type == RWKVType.Raven
            else (
                f"{user}{interface} hi\n\n{bot}{interface} Hi. "
                + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
            )
        )
        logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=preset_system,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            # Caching is best-effort; the model works without it.
            pass


class MusicMidiRWKV(AbstractRWKV):
    """RWKV model emitting MIDI-style tokens (tokenizer-midi / tokenizer-midipiano)."""

    def __init__(self, model, pipeline):
        super().__init__(model, pipeline)
        # Sampling defaults for this model family.
        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.8
        self.top_k = 8
        self.rwkv_type = RWKVType.Music

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Decay penalties by 0.997, then add 1 (ids 127 and >=128) or 0.3 (others)."""
        for seen in occurrence:
            occurrence[seen] *= 0.997  # decay repetition penalty
        weight = 1 if (token == 127 or token >= 128) else 0.3
        occurrence[token] = weight + occurrence.get(token, 0)

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Apply half of each occurrence value as a penalty, then length/token biases."""
        for seen, amount in occurrence.items():
            logits[seen] -= amount * 0.5
        # Token 0 (EOS_ID here) is pushed down before step 2000 and up after it,
        # so pieces are neither too short nor too long.
        logits[0] += (i - 2000) / 500
        logits[127] -= 1  # avoid "t125"

    def fix_tokens(self, tokens) -> List[int]:
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed all tokens in a single forward call; ``newline_adj`` is unused."""
        ids = list(map(int, _tokens))
        self.model_tokens += ids
        logits, self.model_state = self.model.forward(ids, self.model_state)
        return logits, len(ids)

    def delta_postprocess(self, delta: str) -> str:
        """Prefix a space so successive MIDI tokens stay separated in the text."""
        return f" {delta}"


class MusicAbcRWKV(AbstractRWKV):
    """RWKV model emitting ABC notation (abc_tokenizer); no repetition penalties."""

    def __init__(self, model, pipeline):
        super().__init__(model, pipeline)
        self.EOS_ID = 3  # generation stops when this id is sampled
        # Sampling defaults for this model family.
        self.max_tokens_per_generation, self.temperature = 500, 1
        self.top_p, self.top_k = 0.8, 8
        self.rwkv_type = RWKVType.Music

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """No-op: occurrences are not tracked for ABC output."""

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """No-op: logits are sampled unmodified."""

    def fix_tokens(self, tokens) -> List[int]:
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed all tokens in one forward call; ``newline_adj`` is unused."""
        feed = [int(tok) for tok in _tokens]
        self.model_tokens += feed
        logits, self.model_state = self.model.forward(feed, self.model_state)
        return logits, len(feed)

    def delta_postprocess(self, delta: str) -> str:
        return delta


def get_tokenizer(tokenizer_len: int):
    """Pick a tokenizer from the model's embedding row count.

    Returns a built-in tokenizer name ("abc_tokenizer", "rwkv_vocab_v20230424")
    or the path of a JSON tokenizer under backend-python/rwkv_pip/.
    """
    tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
    # (exclusive upper bound on tokenizer_len, tokenizer) — checked in order.
    by_upper_bound = (
        (2176, "abc_tokenizer"),
        (20096, tokenizer_dir + "tokenizer-midipiano.json"),
        (50277, tokenizer_dir + "tokenizer-midi.json"),
        (65536, tokenizer_dir + "20B_tokenizer.json"),
    )
    for upper, choice in by_upper_bound:
        if tokenizer_len < upper:
            return choice
    return "rwkv_vocab_v20230424"


def get_model_path(model_path: str) -> str:
    """Resolve a relative model path against the known install locations.

    Absolute paths are returned unchanged. Otherwise the first existing file
    among ``<root>/model_path`` and ``<root>/build/bin/model_path`` is returned,
    trying roots in this order: cwd, cwd's parent, this file's grandparent and
    great-grandparent directories. If none exists, ``model_path`` is returned.
    """
    if os.path.isabs(model_path):
        return model_path

    cwd = pathlib.Path(os.path.abspath(os.getcwd()))
    here = pathlib.Path(os.path.abspath(__file__))
    search_roots = (
        cwd,  # [cwd](RWKV-Runner)/models/xxx
        cwd.parent,  # [cwd](backend-python)/../models/xxx
        here.parent.parent,  # backend-python/models/xxx
        here.parent.parent.parent,  # RWKV-Runner/models/xxx
    )
    for root in search_roots:
        # Second candidate is the dev-build layout.
        for candidate in (root / model_path, root / "build" / "bin" / model_path):
            if os.path.isfile(candidate):
                return str(candidate)
    return model_path


def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
    """Load a model file and wrap it in the matching AbstractRWKV subclass.

    Args:
        model: model file path (relative paths are resolved by get_model_path).
        strategy: backend strategy string passed to the model loader
            (ignored by the rwkv.cpp backend, per its log message).
        tokenizer: tokenizer name/path; when empty it is inferred from the
            embedding size via get_tokenizer.

    The backend (beta / rwkv.cpp / webgpu / default) is chosen from the global
    Args. Also sets global_var.Midi_Vocab_Config_Type.
    """
    model_path = get_model_path(model)
    args = global_var.get(global_var.Args)
    rwkv_beta = args.rwkv_beta
    rwkv_cpp = getattr(args, "rwkv.cpp")
    webgpu = args.webgpu

    filename, _ = os.path.splitext(os.path.basename(model_path))
    # Match on the file name only: the resolved path is usually absolute, and a
    # directory (e.g. a user name) containing "midi"/"abc" must not affect text models.
    # NOTE(review): the variable is never cleared, so it persists if a text model
    # is loaded later in the same process — confirm whether that is intended.
    lowered_name = filename.lower()
    if "midi" in lowered_name or "abc" in lowered_name:
        os.environ["RWKV_RESCALE_LAYER"] = "999"

    # dynamic import to make RWKV_CUDA_ON work
    if rwkv_beta:
        print("Using rwkv-beta")
        from rwkv_pip.beta.model import RWKV as Model
    elif rwkv_cpp:
        print("Using rwkv.cpp, strategy is ignored")
        from rwkv_pip.cpp.model import RWKV as Model
    elif webgpu:
        print("Using webgpu")
        from rwkv_pip.webgpu.model import RWKV as Model
    else:
        from rwkv_pip.model import RWKV as Model
    from rwkv_pip.utils import PIPELINE

    loaded = Model(model_path, strategy)
    if not tokenizer:
        tokenizer = get_tokenizer(len(loaded.w["emb.weight"]))
    pipeline = PIPELINE(loaded, tokenizer)

    rwkv_map: dict[str, Type[AbstractRWKV]] = {
        "20B_tokenizer": TextRWKV,
        "rwkv_vocab_v20230424": TextRWKV,
        "tokenizer-midi": MusicMidiRWKV,
        "tokenizer-midipiano": MusicMidiRWKV,
        "abc_tokenizer": MusicAbcRWKV,
    }
    tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
    global_var.set(
        global_var.Midi_Vocab_Config_Type,
        (
            global_var.MidiVocabConfig.Piano
            if tokenizer_name == "tokenizer-midipiano"
            else global_var.MidiVocabConfig.Default
        ),
    )

    rwkv: AbstractRWKV
    if tokenizer_name in rwkv_map:
        rwkv = rwkv_map[tokenizer_name](loaded, pipeline)
    else:
        # Unknown tokenizer file: guess the wrapper from its name.
        tokenizer_name = tokenizer_name.lower()
        if "music" in tokenizer_name or "midi" in tokenizer_name:
            rwkv = MusicMidiRWKV(loaded, pipeline)
        elif "abc" in tokenizer_name:
            rwkv = MusicAbcRWKV(loaded, pipeline)
        else:
            rwkv = TextRWKV(loaded, pipeline)
    rwkv.name = filename
    rwkv.version = loaded.version

    return rwkv


class ModelConfigBody(BaseModel):
    # Every field defaults to None, meaning "leave this setting unchanged"
    # (set_rwkv_config skips None fields). The Optional annotations make that
    # default type-correct and let clients send an explicit null.
    max_tokens: Optional[int] = Field(default=None, gt=0, le=102400)
    temperature: Optional[float] = Field(default=None, ge=0, le=3)
    top_p: Optional[float] = Field(default=None, ge=0, le=1)
    presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
    penalty_decay: Optional[float] = Field(default=None, ge=0.99, le=0.999)
    top_k: Optional[int] = Field(default=None, ge=0, le=25)
    global_penalty: Optional[bool] = Field(
        default=None,
        description="When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.",
    )

    model_config = {
        "json_schema_extra": {
            "example": {
                "max_tokens": 1000,
                "temperature": 1,
                "top_p": 0.3,
                "presence_penalty": 0,
                "frequency_penalty": 1,
                "penalty_decay": 0.996,
                "global_penalty": False,
            }
        }
    }


def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
    """Copy every non-None field of ``body`` onto ``model``.

    Fields left as None keep the model's current value. The temperature is
    clamped to at least 0.1.
    """
    requested = {
        "max_tokens_per_generation": body.max_tokens,
        "temperature": (
            None if body.temperature is None else max(body.temperature, 0.1)
        ),
        "top_p": body.top_p,
        "penalty_alpha_presence": body.presence_penalty,
        "penalty_alpha_frequency": body.frequency_penalty,
        "penalty_decay": body.penalty_decay,
        "top_k": body.top_k,
        "global_penalty": body.global_penalty,
    }
    for attribute, value in requested.items():
        if value is not None:
            setattr(model, attribute, value)


def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
    """Snapshot the model's current sampling settings as a ModelConfigBody."""
    settings = dict(
        max_tokens=model.max_tokens_per_generation,
        temperature=model.temperature,
        top_p=model.top_p,
        presence_penalty=model.penalty_alpha_presence,
        frequency_penalty=model.penalty_alpha_frequency,
        penalty_decay=model.penalty_decay,
        top_k=model.top_k,
        global_penalty=model.global_penalty,
    )
    return ModelConfigBody(**settings)