2023-07-25 08:09:31 +00:00
|
|
|
|
from abc import ABC, abstractmethod
|
2023-07-31 14:46:13 +00:00
|
|
|
|
from enum import Enum, auto
|
2023-05-23 03:19:39 +00:00
|
|
|
|
import os
|
|
|
|
|
import pathlib
|
2023-05-28 15:52:38 +00:00
|
|
|
|
import copy
|
2023-07-25 08:10:22 +00:00
|
|
|
|
import re
|
2023-07-28 14:13:19 +00:00
|
|
|
|
from typing import Dict, Iterable, List, Tuple, Union
|
2023-06-12 05:41:51 +00:00
|
|
|
|
from utils.log import quick_log
|
2023-05-28 15:52:38 +00:00
|
|
|
|
from fastapi import HTTPException
|
2023-05-30 15:13:27 +00:00
|
|
|
|
from pydantic import BaseModel, Field
|
2023-06-19 14:51:06 +00:00
|
|
|
|
import numpy as np
|
2023-05-28 15:52:38 +00:00
|
|
|
|
from routes import state_cache
|
2023-08-14 14:07:15 +00:00
|
|
|
|
import global_var
|
2023-05-28 04:53:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Token id that ends generation when sampled (see AbstractRWKV.generate).
END_OF_TEXT = 0
# Single token the tokenizer emits for "\n\n" at the end of a text; non-World
# models expect it as two newline tokens instead (see TextRWKV.fix_tokens).
END_OF_LINE_DOUBLE = 535

# Point torch's extension build directory at the parent of this file's
# directory.
os.environ["TORCH_EXTENSIONS_DIR"] = str(pathlib.Path(__file__).parent.parent.resolve())
|
|
|
|
|
|
|
|
|
|
|
2023-07-31 14:46:13 +00:00
|
|
|
|
class RWKVType(Enum):
    """Model family of a loaded RWKV checkpoint."""

    # Values spelled out explicitly; they match what auto() assigns (1, 2, 3).
    Raven = 1
    World = 2
    Music = 3
|
|
|
|
|
|
|
|
|
|
|
2023-07-25 08:09:31 +00:00
|
|
|
|
class AbstractRWKV(ABC):
    """Shared machinery for RWKV-backed models.

    Loads the RWKV weights and tokenizer pipeline, drives the sampling loop in
    ``generate`` (reusing and storing prefix states through ``state_cache``),
    and computes embeddings. Model-family specifics come from the abstract
    hooks ``adjust_occurrence``, ``adjust_forward_logits``, ``fix_tokens``,
    ``run_rnn`` and ``delta_postprocess``.
    """

    def __init__(self, model: str, strategy: str, tokens_path: str):
        """Load ``model`` with ``strategy`` and build its tokenizer pipeline.

        Args:
            model: Path of the model file; its base name without extension is
                stored as ``self.name``.
            strategy: Strategy string forwarded to the RWKV ``Model`` class.
            tokens_path: Tokenizer path forwarded to ``PIPELINE``.
        """
        rwkv_beta = global_var.get(global_var.Args).rwkv_beta

        # dynamic import to make RWKV_CUDA_ON work
        if rwkv_beta:
            from rwkv_pip.beta.model import (
                RWKV as Model,
            )
        else:
            from rwkv_pip.model import (
                RWKV as Model,
            )
        from rwkv_pip.utils import PIPELINE

        filename, _ = os.path.splitext(os.path.basename(model))
        self.name = filename
        self.model = Model(model, strategy)
        self.pipeline = PIPELINE(self.model, tokens_path)
        # Recurrent state plus the tokens already fed into it; reset or
        # restored from the prefix cache at the start of each generation.
        self.model_state = None
        self.model_tokens = []
        self.rwkv_type: RWKVType = None

        # Default sampling settings (subclasses and set_rwkv_config override).
        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.3
        self.top_k = 0
        self.penalty_alpha_presence = 0
        self.penalty_alpha_frequency = 1

    @abstractmethod
    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Record in ``occurrence`` that ``token`` was just sampled."""

    @abstractmethod
    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Modify ``logits`` in place before sampling step ``i``."""

    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
    @abstractmethod
    def fix_tokens(self, tokens) -> List[int]:
        """Post-process freshly encoded ``tokens`` before they reach the model."""

    @abstractmethod
    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed ``_tokens`` through the model.

        Implementations append the tokens to ``self.model_tokens``, advance
        ``self.model_state`` and return ``(logits_after_last_token, count)``.
        """

    @abstractmethod
    def delta_postprocess(self, delta: str) -> str:
        """Transform newly decoded text before it is appended to the response."""

    def get_embedding(self, input: str, fast_mode: bool) -> Tuple[List[float], int]:
        """Embed ``input`` and return ``(unit_length_vector, token_count)``.

        With ``fast_mode`` only the first layer is evaluated (see
        ``__fast_embedding``) and the chat state is left untouched. Otherwise
        the chat state is reset, the whole prompt is run, and
        ``self.model_state[-11]`` is used as the raw embedding.
        """
        tokens = self.fix_tokens(self.pipeline.encode(input))
        if fast_mode:
            embedding, token_len = self.__fast_embedding(tokens, None)
        else:
            self.model_state = None
            self.model_tokens = []
            _, token_len = self.run_rnn(tokens)
            embedding = self.model_state[-11].tolist()

        norm = np.linalg.norm(embedding)
        if norm == 0:
            # A zero vector cannot be normalised; dividing would yield NaNs.
            return list(embedding), token_len
        return (np.asarray(embedding) / norm).tolist(), token_len

    def __fast_embedding(self, tokens: List[str], state):
        """Embed ``tokens`` using only layer 0's attention block.

        Returns ``(state[0].tolist(), token_len)``: the layer-0 ``att_xx``
        state slot after the pass, and the number of input tokens. Skipping
        every other layer is what makes this mode fast. ``state`` is ``None``
        to start from a fresh zero state.
        """
        import torch

        tokens = [int(x) for x in tokens]
        token_len = len(tokens)
        model = self.model

        with torch.no_grad():
            w = model.w
            args = model.args

            # Only layer 0 is evaluated, so only its five state slots are used.
            layer = 0
            base = layer * 5
            dd = model.strategy[layer]
            dev = dd.device
            atype = dd.atype
            wtype = dd.wtype
            is_int8 = wtype == torch.uint8

            if state is None:
                # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
                state = [None] * args.n_layer * 5
                state[base + 0] = torch.zeros(
                    args.n_embd, dtype=atype, requires_grad=False, device=dev
                ).contiguous()
                state[base + 1] = torch.zeros(
                    args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                ).contiguous()
                state[base + 2] = torch.zeros(
                    args.n_embd, dtype=torch.float, requires_grad=False, device=dev
                ).contiguous()
                state[base + 3] = (
                    torch.zeros(
                        args.n_embd,
                        dtype=torch.float,
                        requires_grad=False,
                        device=dev,
                    ).contiguous()
                    - 1e30
                )
                state[base + 4] = torch.zeros(
                    args.n_embd, dtype=atype, requires_grad=False, device=dev
                ).contiguous()

            seq_mode = len(tokens) > 1
            x = w["emb.weight"][tokens if seq_mode else tokens[0]]

            if seq_mode:
                # .get: an unset RWKV_CUDA_ON simply means "CUDA kernels off".
                if "cuda" in str(dev) and os.environ.get("RWKV_CUDA_ON") == "1":
                    ATT = model.cuda_att_seq_i8 if is_int8 else model.cuda_att_seq
                else:
                    ATT = model.att_seq_i8 if is_int8 else model.att_seq
            else:
                ATT = model.att_one_i8 if is_int8 else model.att_one

            x = x.to(dtype=atype, device=dev)

            bbb = f"blocks.{layer}."
            att = f"blocks.{layer}.att."
            kw = w[f"{att}key.weight"]
            vw = w[f"{att}value.weight"]
            rw = w[f"{att}receptance.weight"]
            ow = w[f"{att}output.weight"]
            if dd.stream:
                kw = kw.to(device=dev, non_blocking=True)
                vw = vw.to(device=dev, non_blocking=True)
                rw = rw.to(device=dev, non_blocking=True)
                ow = ow.to(device=dev, non_blocking=True)

            # int8 weights come with mx/rx/my/ry tensors; for other dtypes x is
            # passed in those positions as a placeholder.
            quant = [
                w[f"{att}{part}.weight_{suffix}"] if is_int8 else x
                for part in ("key", "value", "receptance", "output")
                for suffix in ("mx", "rx", "my", "ry")
            ]

            (
                x,
                state[base + 0],
                state[base + 1],
                state[base + 2],
                state[base + 3],
            ) = ATT(
                x,
                state[base + 0],
                state[base + 1],
                state[base + 2],
                state[base + 3],
                w[f"{bbb}ln1.weight"],
                w[f"{bbb}ln1.bias"],
                w[f"{att}time_mix_k"],
                w[f"{att}time_mix_v"],
                w[f"{att}time_mix_r"],
                w[f"{att}time_decay"],
                w[f"{att}time_first"],
                kw,
                vw,
                rw,
                ow,
                *quant,
            )

            return state[base + 0].tolist(), token_len

    @staticmethod
    def _find_stop(response: str, stop: Union[str, List[str], None]) -> int:
        """Return the index of the earliest stop sequence in ``response``, or -1.

        ``stop`` may be one string or a list/tuple of strings. Stop sequences
        are matched literally (no regex semantics), and empty strings are
        ignored because they would match at every position.
        """
        if stop is None:
            return -1
        if isinstance(stop, str):
            candidates = [stop]
        elif isinstance(stop, (list, tuple)):
            candidates = stop
        else:
            return -1
        hits = [response.find(s) for s in candidates if s]
        return min((h for h in hits if h != -1), default=-1)

    def _cache_state(self, prompt: str, logits) -> None:
        """Best-effort: store the current state under ``prompt`` in ``state_cache``.

        An ``HTTPException`` raised by the cache is deliberately ignored so a
        caching failure never aborts generation.
        """
        try:
            state_cache.add_state(
                state_cache.AddStateBody(
                    prompt=prompt,
                    tokens=self.model_tokens,
                    state=self.model_state,
                    logits=logits,
                )
            )
        except HTTPException:
            pass

    def generate(
        self, prompt: str, stop: Union[str, List[str], None] = None
    ) -> Iterable[Tuple[str, str, int, int]]:
        """Stream a completion of ``prompt``.

        Yields ``(response_so_far, delta, prompt_token_len, completion_token_len)``
        after each cleanly decoded token. Generation ends with a final yield
        whose delta is ``""`` when END_OF_TEXT is sampled or when a stop
        sequence appears (the response is then cut just before it), or after
        ``max_tokens_per_generation`` tokens.

        The longest cached prefix of ``prompt`` is restored from
        ``state_cache`` when available, and new states are stored back on a
        best-effort basis.
        """
        quick_log(None, None, "Generation Prompt:\n" + prompt)
        cache = None
        delta_prompt = prompt
        try:
            cache = state_cache.longest_prefix_state(
                state_cache.LongestPrefixStateBody(prompt=prompt), None
            )
        except HTTPException:
            pass
        if cache is None or cache["prompt"] == "":
            self.model_state = None
            self.model_tokens = []
        else:
            delta_prompt = prompt[len(cache["prompt"]) :]
            self.model_state = copy.deepcopy(cache["state"])
            self.model_tokens = copy.deepcopy(cache["tokens"])
            logits = copy.deepcopy(cache["logits"])

        # NOTE(review): with no usable cache and an empty prompt, ``logits`` is
        # never assigned and the sampling loop raises UnboundLocalError.
        prompt_token_len = 0
        if delta_prompt != "":
            logits, prompt_token_len = self.run_rnn(
                self.fix_tokens(self.pipeline.encode(delta_prompt))
            )
            self._cache_state(prompt, logits)

        begin = len(self.model_tokens)
        out_last = begin

        occurrence: Dict = {}

        completion_token_len = 0
        response = ""
        for i in range(self.max_tokens_per_generation):
            self.adjust_forward_logits(logits, occurrence, i)

            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
            )

            if token == END_OF_TEXT:
                yield response, "", prompt_token_len, completion_token_len
                break

            self.adjust_occurrence(occurrence, token)

            logits, _ = self.run_rnn([token])
            completion_token_len = completion_token_len + 1
            delta: str = self.delta_postprocess(
                self.pipeline.decode(self.model_tokens[out_last:])
            )
            if "\ufffd" not in delta:  # avoid utf-8 display issues
                response += delta
                stop_at = self._find_stop(response, stop)
                if stop_at != -1:
                    # Cache the full text (including the stop sequence) that
                    # the current state corresponds to, then truncate.
                    self._cache_state(prompt + response, logits)
                    response = response[:stop_at]
                    yield response, "", prompt_token_len, completion_token_len
                    break
                out_last = begin + i + 1
                if i == self.max_tokens_per_generation - 1:
                    self._cache_state(prompt + response, logits)
                yield response, delta, prompt_token_len, completion_token_len
|
2023-05-17 03:39:00 +00:00
|
|
|
|
|
|
|
|
|
|
2023-07-25 08:09:31 +00:00
|
|
|
|
class TextRWKV(AbstractRWKV):
    """Text/chat RWKV model (Raven or World family)."""

    def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
        """Load the model, choose chat roles from its file name, warm the cache.

        A model whose name contains "world" (case-insensitive) is treated as
        RWKVType.World (roles Question/Answer, newline token 11); any other
        name as RWKVType.Raven (roles Bob/Alice, newline token 187).
        """
        super().__init__(model, strategy, tokens_path)

        # Maximum number of tokens handed to one model.forward call.
        self.CHUNK_LEN = 256

        # Sampling defaults for text generation.
        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.3
        self.top_k = 0
        self.penalty_alpha_presence = 0
        self.penalty_alpha_frequency = 1

        self.interface = ":"
        if "world" not in self.name.lower():
            self.rwkv_type = RWKVType.Raven
            self.user, self.bot = "Bob", "Alice"
            self.END_OF_LINE = 187
        else:
            self.rwkv_type = RWKVType.World
            self.user, self.bot = "Question", "Answer"
            self.END_OF_LINE = 11

        # Punctuation tokens whose logit run_rnn suppresses right after one of
        # them was fed, so they are not emitted twice in a row.
        self.AVOID_REPEAT_TOKENS = []
        for ch in ",:?!":
            ids = self.pipeline.encode(ch)
            assert len(ids) == 1
            self.AVOID_REPEAT_TOKENS.extend(ids)

        self.__preload()

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Decay every count by 0.996, then count one more ``token``."""
        for key in occurrence:
            occurrence[key] *= 0.996
        occurrence[token] = occurrence.get(token, 0) + 1

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Apply presence/frequency penalties; on step 0 also seed ``occurrence``.

        The seeding from ``self.model_tokens`` happens after the penalties are
        applied, so it only influences later steps.
        """
        for tok, count in occurrence.items():
            logits[tok] -= (
                self.penalty_alpha_presence + count * self.penalty_alpha_frequency
            )

        if i != 0:
            return
        for tok in map(int, self.model_tokens):
            for key in occurrence:
                occurrence[key] *= 0.996
            occurrence[tok] = occurrence.get(tok, 0) + 1

    # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
    def fix_tokens(self, tokens) -> List[int]:
        """For non-World models, split a trailing END_OF_LINE_DOUBLE into two END_OF_LINE."""
        if self.rwkv_type == RWKVType.World:
            return tokens
        ends_with_double_newline = len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE
        if not ends_with_double_newline:
            return tokens
        return tokens[:-1] + [self.END_OF_LINE] * 2

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed ``_tokens`` in CHUNK_LEN-sized pieces; return final logits and count.

        ``newline_adj`` is added to the END_OF_LINE logit. If the last fed
        token is one of AVOID_REPEAT_TOKENS, its logit is set to -999999999.
        """
        ids = [int(t) for t in _tokens]
        self.model_tokens += ids

        for start in range(0, len(ids), self.CHUNK_LEN):
            out, self.model_state = self.model.forward(
                ids[start : start + self.CHUNK_LEN], self.model_state
            )

        out[self.END_OF_LINE] += newline_adj  # adjust \n probability

        last = self.model_tokens[-1]
        if last in self.AVOID_REPEAT_TOKENS:
            out[last] = -999999999
        return out, len(ids)

    def delta_postprocess(self, delta: str) -> str:
        """Text deltas are used unchanged."""
        return delta

    def __preload(self):
        """Run the preset system prompt once and store its state in state_cache."""
        interface, user, bot = self.interface, self.user, self.bot
        if self.rwkv_type == RWKVType.Raven:
            preset_system = f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
        else:
            preset_system = (
                f"{user}{interface} hi\n\n{bot}{interface} Hi. "
                + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
            )

        logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
        # Best-effort: a cache failure must not break model loading.
        try:
            body = state_cache.AddStateBody(
                prompt=preset_system,
                tokens=self.model_tokens,
                state=self.model_state,
                logits=logits,
            )
            state_cache.add_state(body)
        except HTTPException:
            pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MusicRWKV(AbstractRWKV):
    """RWKV model generating music token streams (RWKVType.Music)."""

    def __init__(self, model: str, strategy: str, tokens_path: str):
        super().__init__(model, strategy, tokens_path)

        # Sampling defaults for music generation.
        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.8
        self.top_k = 8

        self.rwkv_type = RWKVType.Music

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Decay all counts by 0.997, then add 1 for token 127 or >= 128, else 0.3."""
        for key in occurrence:
            occurrence[key] *= 0.997  #### decay repetition penalty
        weight = 1 if (token >= 128 or token == 127) else 0.3
        occurrence[token] = weight + occurrence.get(token, 0)

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Apply a fixed 0.5x repetition penalty plus length and token-127 biases.

        The configurable penalty_alpha_* settings are not used here.
        """
        for tok, count in occurrence.items():
            logits[tok] -= count * 0.5

        logits[0] += (i - 2000) / 500  # try not to be too short or too long
        logits[127] -= 1  # avoid "t125"

    def fix_tokens(self, tokens) -> List[int]:
        """Music tokens are used unchanged."""
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed all tokens in a single forward call; ``newline_adj`` is unused."""
        ids = [int(t) for t in _tokens]
        self.model_tokens += ids
        out, self.model_state = self.model.forward(ids, self.model_state)
        return out, len(ids)

    def delta_postprocess(self, delta: str) -> str:
        """Prefix each decoded chunk with a space."""
        return " " + delta
|
|
|
|
|
|
|
|
|
|
|
2023-05-17 03:39:00 +00:00
|
|
|
|
class ModelConfigBody(BaseModel):
    """Sampling settings read by get_rwkv_config and applied by set_rwkv_config.

    Every field defaults to None; set_rwkv_config skips None fields.
    """

    max_tokens: int = Field(None, gt=0, le=102400)
    temperature: float = Field(None, ge=0, le=2)
    top_p: float = Field(None, ge=0, le=1)
    presence_penalty: float = Field(None, ge=-2, le=2)
    frequency_penalty: float = Field(None, ge=-2, le=2)

    class Config:
        schema_extra = {
            "example": dict(
                max_tokens=1000,
                temperature=1.2,
                top_p=0.5,
                presence_penalty=0.4,
                frequency_penalty=0.4,
            )
        }
|
|
|
|
|
|
2023-05-17 03:39:00 +00:00
|
|
|
|
|
2023-07-25 08:09:31 +00:00
|
|
|
|
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
    """Copy every non-None field of ``body`` onto ``model``'s sampling settings.

    A temperature below 0.1 is raised to 0.1.
    """
    temperature = body.temperature
    if temperature is not None:
        temperature = max(temperature, 0.1)

    updates = {
        "max_tokens_per_generation": body.max_tokens,
        "temperature": temperature,
        "top_p": body.top_p,
        "penalty_alpha_presence": body.presence_penalty,
        "penalty_alpha_frequency": body.frequency_penalty,
    }
    for attr, value in updates.items():
        if value is not None:
            setattr(model, attr, value)
|
|
|
|
|
|
|
|
|
|
|
2023-07-25 08:09:31 +00:00
|
|
|
|
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
    """Snapshot ``model``'s current sampling settings as a ModelConfigBody."""
    settings = {
        "max_tokens": model.max_tokens_per_generation,
        "temperature": model.temperature,
        "top_p": model.top_p,
        "presence_penalty": model.penalty_alpha_presence,
        "frequency_penalty": model.penalty_alpha_frequency,
    }
    return ModelConfigBody(**settings)
|