support for rwkv-4-world

This commit is contained in:
josc146 2023-05-28 12:53:14 +08:00
parent b7fb8ed898
commit 94971bb666
8 changed files with 65918 additions and 65 deletions

View File

@@ -11,10 +11,6 @@ import global_var
router = APIRouter()
interface = ":"
user = "Bob"
bot = "Alice"
class Message(BaseModel):
    role: str
@@ -44,17 +40,27 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
    else:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
completion_text = f""" interface = model.interface
user = model.user
bot = model.bot
completion_text = (
f"""
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
{bot} is very intelligent, creative and friendly. \
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if user == "Bob"
else ""
)
    for message in body.messages:
        if message.role == "system":
            completion_text = (
                f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
if user == "Bob"
else ""
                + message.content.replace("\\n", "\n")
                .replace("\r\n", "\n")
                .replace("\n\n", "\n")
@@ -101,8 +107,7 @@ The following is a coherent verbose detailed conversation between a girl named {
    set_rwkv_config(model, global_var.get(global_var.Model_Config))
    set_rwkv_config(model, body)
    if body.stream:
-            for response, delta in rwkv_generate(
-                model,
+            for response, delta in model.generate(
                completion_text,
                stop=f"\n\n{user}" if body.stop is None else body.stop,
            ):
@@ -141,8 +146,7 @@ The following is a coherent verbose detailed conversation between a girl named {
yield "[DONE]" yield "[DONE]"
else: else:
response = None response = None
-        for response, delta in rwkv_generate(
-            model,
+        for response, delta in model.generate(
            completion_text,
            stop=f"\n\n{user}" if body.stop is None else body.stop,
        ):
@@ -186,7 +190,7 @@ async def completions(body: CompletionBody, request: Request):
    model: RWKV = global_var.get(global_var.Model)
    if model is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "model not loaded")
    if body.prompt is None or body.prompt == "":
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "prompt not found")
@@ -200,9 +204,7 @@ async def completions(body: CompletionBody, request: Request):
    set_rwkv_config(model, global_var.get(global_var.Model_Config))
    set_rwkv_config(model, body)
    if body.stream:
-            for response, delta in rwkv_generate(
-                model, body.prompt, stop=body.stop
-            ):
+            for response, delta in model.generate(body.prompt, stop=body.stop):
                if await request.is_disconnected():
                    break
                yield json.dumps(
@@ -238,9 +240,7 @@ async def completions(body: CompletionBody, request: Request):
yield "[DONE]" yield "[DONE]"
else: else:
response = None response = None
-        for response, delta in rwkv_generate(
-            model, body.prompt, stop=body.stop
-        ):
+        for response, delta in model.generate(body.prompt, stop=body.stop):
            if await request.is_disconnected():
                break
        # torch_gc()

View File

@@ -11,6 +11,19 @@ import GPUtil
router = APIRouter()
+def get_tokens_path(model_path: str):
+    model_path = model_path.lower()
+    default_tokens_path = (
+        f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
+    )
+    if "raven" in model_path:
+        return default_tokens_path
+    elif "world" in model_path:
+        return "rwkv_vocab_v20230424"
+    else:
+        return default_tokens_path
class SwitchModelBody(BaseModel):
    model: str
    strategy: str
@@ -36,7 +49,7 @@ def switch_model(body: SwitchModelBody, response: Response):
            RWKV(
                model=body.model,
                strategy=body.strategy,
-                tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
+                tokens_path=get_tokens_path(body.model),
            ),
        )
    except Exception as e:
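get_tokens_path switches the tokenizer on the model filename: Raven models (and anything unrecognized) keep 20B_tokenizer.json, while World models get the name rwkv_vocab_v20230424, which the PIPELINE added later in this commit resolves to the bundled vocabulary file rather than treating it as a path. Illustrative calls (the filenames are hypothetical):

get_tokens_path("RWKV-4-Raven-7B-v11.pth")  # -> .../20B_tokenizer.json
get_tokens_path("RWKV-4-World-3B-v1.pth")   # -> "rwkv_vocab_v20230424"
get_tokens_path("some-custom-model.pth")    # -> .../20B_tokenizer.json (default branch)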

View File

@@ -0,0 +1,106 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
class TRIE:
__slots__ = tuple("ch,to,values,front".split(","))
to: list
values: set
def __init__(self, front=None, ch=None):
self.ch = ch
self.to = [None for ch in range(256)]
self.values = set()
self.front = front
def __repr__(self):
fr = self
ret = []
while fr != None:
if fr.ch != None:
ret.append(fr.ch)
fr = fr.front
return "<TRIE %s %s>" % (ret[::-1], self.values)
def add(self, key: bytes, idx: int = 0, val=None):
if idx == len(key):
if val is None:
val = key
self.values.add(val)
return self
ch = key[idx]
if self.to[ch] is None:
self.to[ch] = TRIE(front=self, ch=ch)
return self.to[ch].add(key, idx=idx + 1, val=val)
def find_longest(self, key: bytes, idx: int = 0):
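        # walk the trie from key[idx:], remembering the last node where a complete token ended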
u: TRIE = self
ch: int = key[idx]
while u.to[ch] is not None:
u = u.to[ch]
idx += 1
if u.values:
ret = idx, u, u.values
if idx == len(key):
break
ch = key[idx]
return ret
class TRIE_TOKENIZER:
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
with open(file_name, "r", encoding="utf-8") as f:
lines = f.readlines()
for l in lines:
idx = int(l[: l.index(" ")])
x = eval(l[l.index(" ") : l.rindex(" ")])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(" ") :])
sorted += [x]
self.idx2token[idx] = x
self.token2idx = {}
for k, v in self.idx2token.items():
self.token2idx[v] = int(k)
self.root = TRIE()
for t, i in self.token2idx.items():
_ = self.root.add(t, val=(t, i))
def encodeBytes(self, src: bytes) -> list[int]:
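        # greedy longest-match: repeatedly consume the longest vocab entry at the current position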
idx: int = 0
tokens: list[int] = []
while idx < len(src):
_idx: int = idx
idx, _, values = self.root.find_longest(src, idx)
assert idx != _idx
_, token = next(iter(values))
tokens.append(token)
return tokens
def decodeBytes(self, tokens):
return b"".join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src):
return self.encodeBytes(src.encode("utf-8"))
def decode(self, tokens):
try:
return self.decodeBytes(tokens).decode("utf-8")
except:
return "\ufffd" # bad utf-8
def printTokens(self, tokens):
for i in tokens:
s = self.idx2token[i]
try:
s = s.decode("utf-8")
except:
pass
print(f"{repr(s)}{i}", end=" ")
print()
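A round-trip sketch for the trie tokenizer, assuming rwkv_vocab_v20230424.txt sits next to this module (the file ships in this commit):

tokenizer = TRIE_TOKENIZER("rwkv_vocab_v20230424.txt")
ids = tokenizer.encode("Hello, world")  # longest-match over the UTF-8 bytes
assert tokenizer.decode(ids) == "Hello, world"
tokenizer.printTokens(ids)  # prints each piece alongside its token id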

File diff suppressed because it is too large

View File

@@ -0,0 +1,142 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, sys
import numpy as np
import torch
from torch.nn import functional as F
class PIPELINE_ARGS:
def __init__(
self,
temperature=1.0,
top_p=0.85,
top_k=0,
alpha_frequency=0.2,
alpha_presence=0.2,
token_ban=[],
token_stop=[],
chunk_len=256,
):
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
self.token_ban = token_ban # ban the generation of some tokens
self.token_stop = token_stop # stop generation whenever you see any token here
self.chunk_len = (
chunk_len # split input into chunks to save VRAM (shorter -> slower)
)
class PIPELINE:
def __init__(self, model, WORD_NAME):
self.model = model
if WORD_NAME == "cl100k_base":
import tiktoken
self.tokenizer = tiktoken.get_encoding(WORD_NAME)
elif WORD_NAME == "rwkv_vocab_v20230424":
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from rwkv_tokenizer import TRIE_TOKENIZER
self.tokenizer = TRIE_TOKENIZER(
os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
)
else:
from tokenizers import Tokenizer
self.tokenizer = Tokenizer.from_file(WORD_NAME)
def refine_context(self, context):
context = context.strip().split("\n")
for c in range(len(context)):
context[c] = context[c].strip().strip("\u3000").strip("\r")
context = list(filter(lambda c: c != "", context))
context = "\n" + ("\n".join(context)).strip()
if context == "":
context = "\n"
return context
def encode(self, x):
if "Tokenizer" in str(type(self.tokenizer)):
return self.tokenizer.encode(x).ids
else:
return self.tokenizer.encode(x)
def decode(self, x):
return self.tokenizer.decode(x)
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
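        # top-p (nucleus) sampling with optional top-k cutoff and temperature scaling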
probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
if probs.device == torch.device("cpu"):
probs = probs.numpy()
sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
if temperature != 1.0:
probs = probs ** (1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return int(out)
else:
sorted_ids = torch.argsort(probs)
sorted_probs = probs[sorted_ids]
sorted_probs = torch.flip(sorted_probs, dims=(0,))
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
if temperature != 1.0:
probs = probs ** (1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return int(out)
def generate(
self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None
):
all_tokens = []
out_last = 0
out_str = ""
occurrence = {}
for i in range(token_count):
# forward & adjust prob.
tokens = self.encode(ctx) if i == 0 else [token]
while len(tokens) > 0:
out, state = self.model.forward(tokens[: args.chunk_len], state)
tokens = tokens[args.chunk_len :]
for n in args.token_ban:
out[n] = -float("inf")
for n in occurrence:
out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
# sampler
token = self.sample_logits(
out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k
)
if token in args.token_stop:
break
all_tokens += [token]
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
# output
tmp = self.decode(all_tokens[out_last:])
if "\ufffd" not in tmp: # is valid utf-8 string?
if callback:
callback(tmp)
out_str += tmp
out_last = i + 1
return out_str
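A usage sketch for the pipeline; the model path and strategy below are placeholders, not values from this commit:

from rwkv.model import RWKV

model = RWKV(model="./RWKV-4-World-3B.pth", strategy="cpu fp32")  # placeholder path
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")  # selects the TRIE_TOKENIZER above
args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, token_stop=[0])  # 0 is end-of-text elsewhere in this commit
print(pipeline.generate("Hello", token_count=50, args=args))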

View File

@@ -1,8 +1,103 @@
import os
import pathlib
-from typing import Dict
+from typing import Dict, List
-from langchain.llms import RWKV
from pydantic import BaseModel
+from rwkv_pip.utils import PIPELINE
+END_OF_TEXT = 0
+END_OF_LINE = 187
+
+os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
+
+
+class RWKV:
+    def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
+        from rwkv.model import RWKV as Model  # dynamic import to make RWKV_CUDA_ON work
+
+        self.model = Model(model, strategy)
+        self.pipeline = PIPELINE(self.model, tokens_path)
+        self.model_state = None
+        self.model_tokens = []
+        self.CHUNK_LEN = 256
+        self.max_tokens_per_generation = 500
+        self.temperature = 1
+        self.top_p = 0.5
+        self.penalty_alpha_presence = 0.4
+        self.penalty_alpha_frequency = 0.4
+        self.interface = ":"
+        if "rwkv_vocab" in tokens_path:
+            self.user = "Human"
+            self.bot = "Bot"
+        else:
+            self.user = "Bob"
+            self.bot = "Alice"
+        self.AVOID_REPEAT_TOKENS = []
+        AVOID_REPEAT = "，：？！"
+        for i in AVOID_REPEAT:
+            dd = self.pipeline.encode(i)
+            assert len(dd) == 1
+            self.AVOID_REPEAT_TOKENS += dd
+
+    def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
+        tokens = [int(x) for x in _tokens]
+        self.model_tokens += tokens
+        while len(tokens) > 0:
+            out, self.model_state = self.model.forward(
+                tokens[: self.CHUNK_LEN], self.model_state
+            )
+            tokens = tokens[self.CHUNK_LEN :]
+        out[END_OF_LINE] += newline_adj  # adjust \n probability
+        if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
+            out[self.model_tokens[-1]] = -999999999
+        return out
+
+    def generate(self, prompt: str, stop: str = None):
+        self.model_state = None
+        self.model_tokens = []
+        logits = self.run_rnn(self.pipeline.encode(prompt))
+        begin = len(self.model_tokens)
+        out_last = begin
+        occurrence: Dict = {}
+        response = ""
+        for i in range(self.max_tokens_per_generation):
+            for n in occurrence:
+                logits[n] -= (
+                    self.penalty_alpha_presence
+                    + occurrence[n] * self.penalty_alpha_frequency
+                )
+            token = self.pipeline.sample_logits(
+                logits, temperature=self.temperature, top_p=self.top_p
+            )
+            if token == END_OF_TEXT:
+                break
+            if token not in occurrence:
+                occurrence[token] = 1
+            else:
+                occurrence[token] += 1
+            logits = self.run_rnn([token])
+            delta: str = self.pipeline.decode(self.model_tokens[out_last:])
+            if "\ufffd" not in delta:  # avoid utf-8 display issues
+                response += delta
+                if stop is not None:
+                    if stop in response:
+                        response = response.split(stop)[0]
+                        yield response, ""
+                        break
+                out_last = begin + i + 1
+                yield response, delta
class ModelConfigBody(BaseModel):
@@ -34,47 +129,3 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody:
        presence_penalty=model.penalty_alpha_presence,
        frequency_penalty=model.penalty_alpha_frequency,
    )
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
model.model_state = None
model.model_tokens = []
logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
begin = len(model.model_tokens)
out_last = begin
occurrence: Dict = {}
response = ""
for i in range(model.max_tokens_per_generation):
for n in occurrence:
logits[n] -= (
model.penalty_alpha_presence
+ occurrence[n] * model.penalty_alpha_frequency
)
token = model.pipeline.sample_logits(
logits, temperature=model.temperature, top_p=model.top_p
)
END_OF_TEXT = 0
if token == END_OF_TEXT:
break
if token not in occurrence:
occurrence[token] = 1
else:
occurrence[token] += 1
logits = model.run_rnn([token])
delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
if "\ufffd" not in delta: # avoid utf-8 display issues
response += delta
if stop is not None:
if stop in response:
response = response.split(stop)[0]
yield response, ""
break
yield response, delta
out_last = begin + i + 1
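Callers iterate the new RWKV.generate method in place of the removed rwkv_generate: each step yields the accumulated response plus the newly decoded delta, and the stop string is trimmed before the final yield. A consumption sketch (the model path, strategy, and stop string below are placeholders):

model = RWKV(
    model="./RWKV-4-World-3B.pth",  # placeholder path
    strategy="cpu fp32",
    tokens_path="rwkv_vocab_v20230424",
)
for response, delta in model.generate("Hello", stop="\n\nHuman"):
    print(delta, end="", flush=True)  # `response` carries the running total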

View File

@@ -54,6 +54,14 @@
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/utils/torch.py", "url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/utils/torch.py",
"path": "backend-python/utils/torch.py" "path": "backend-python/utils/torch.py"
}, },
+    {
+      "url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/rwkv_tokenizer.py",
+      "path": "backend-python/rwkv_pip/rwkv_tokenizer.py"
+    },
+    {
+      "url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/utils.py",
+      "path": "backend-python/rwkv_pip/utils.py"
+    },
    {
      "url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd",
      "path": "backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd"
@@ -67,8 +75,12 @@
"path": "backend-python/wkv_cuda_utils/wkv_cuda_model.py" "path": "backend-python/wkv_cuda_utils/wkv_cuda_model.py"
}, },
{ {
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/20B_tokenizer.json", "url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/rwkv_vocab_v20230424.txt",
"path": "backend-python/20B_tokenizer.json" "path": "backend-python/rwkv_pip/rwkv_vocab_v20230424.txt"
},
{
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/20B_tokenizer.json",
"path": "backend-python/rwkv_pip/20B_tokenizer.json"
    },
    {
      "url": "https://cdn.jsdelivr.net/gh/pypa/get-pip/public/get-pip.py",