support for rwkv-4-world
This commit is contained in:
parent
b7fb8ed898
commit
94971bb666
@ -11,10 +11,6 @@ import global_var
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
interface = ":"
|
|
||||||
user = "Bob"
|
|
||||||
bot = "Alice"
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
@ -44,17 +40,27 @@ async def chat_completions(body: ChatCompletionBody, request: Request):
|
|||||||
else:
|
else:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "no question found")
|
||||||
|
|
||||||
completion_text = f"""
|
interface = model.interface
|
||||||
|
user = model.user
|
||||||
|
bot = model.bot
|
||||||
|
|
||||||
|
completion_text = (
|
||||||
|
f"""
|
||||||
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
|
The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. \
|
||||||
{bot} is very intelligent, creative and friendly. \
|
{bot} is very intelligent, creative and friendly. \
|
||||||
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
|
{bot} is unlikely to disagree with {user}, and {bot} doesn't like to ask {user} questions. \
|
||||||
{bot} likes to tell {user} a lot about herself and her opinions. \
|
{bot} likes to tell {user} a lot about herself and her opinions. \
|
||||||
{bot} usually gives {user} kind, helpful and informative advices.\n
|
{bot} usually gives {user} kind, helpful and informative advices.\n
|
||||||
"""
|
"""
|
||||||
|
if user == "Bob"
|
||||||
|
else ""
|
||||||
|
)
|
||||||
for message in body.messages:
|
for message in body.messages:
|
||||||
if message.role == "system":
|
if message.role == "system":
|
||||||
completion_text = (
|
completion_text = (
|
||||||
f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
|
f"The following is a coherent verbose detailed conversation between a girl named {bot} and her friend {user}. "
|
||||||
|
if user == "Bob"
|
||||||
|
else ""
|
||||||
+ message.content.replace("\\n", "\n")
|
+ message.content.replace("\\n", "\n")
|
||||||
.replace("\r\n", "\n")
|
.replace("\r\n", "\n")
|
||||||
.replace("\n\n", "\n")
|
.replace("\n\n", "\n")
|
||||||
@ -101,8 +107,7 @@ The following is a coherent verbose detailed conversation between a girl named {
|
|||||||
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||||
set_rwkv_config(model, body)
|
set_rwkv_config(model, body)
|
||||||
if body.stream:
|
if body.stream:
|
||||||
for response, delta in rwkv_generate(
|
for response, delta in model.generate(
|
||||||
model,
|
|
||||||
completion_text,
|
completion_text,
|
||||||
stop=f"\n\n{user}" if body.stop is None else body.stop,
|
stop=f"\n\n{user}" if body.stop is None else body.stop,
|
||||||
):
|
):
|
||||||
@ -141,8 +146,7 @@ The following is a coherent verbose detailed conversation between a girl named {
|
|||||||
yield "[DONE]"
|
yield "[DONE]"
|
||||||
else:
|
else:
|
||||||
response = None
|
response = None
|
||||||
for response, delta in rwkv_generate(
|
for response, delta in model.generate(
|
||||||
model,
|
|
||||||
completion_text,
|
completion_text,
|
||||||
stop=f"\n\n{user}" if body.stop is None else body.stop,
|
stop=f"\n\n{user}" if body.stop is None else body.stop,
|
||||||
):
|
):
|
||||||
@ -200,9 +204,7 @@ async def completions(body: CompletionBody, request: Request):
|
|||||||
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
set_rwkv_config(model, global_var.get(global_var.Model_Config))
|
||||||
set_rwkv_config(model, body)
|
set_rwkv_config(model, body)
|
||||||
if body.stream:
|
if body.stream:
|
||||||
for response, delta in rwkv_generate(
|
for response, delta in model.generate(body.prompt, stop=body.stop):
|
||||||
model, body.prompt, stop=body.stop
|
|
||||||
):
|
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
break
|
break
|
||||||
yield json.dumps(
|
yield json.dumps(
|
||||||
@ -238,9 +240,7 @@ async def completions(body: CompletionBody, request: Request):
|
|||||||
yield "[DONE]"
|
yield "[DONE]"
|
||||||
else:
|
else:
|
||||||
response = None
|
response = None
|
||||||
for response, delta in rwkv_generate(
|
for response, delta in model.generate(body.prompt, stop=body.stop):
|
||||||
model, body.prompt, stop=body.stop
|
|
||||||
):
|
|
||||||
if await request.is_disconnected():
|
if await request.is_disconnected():
|
||||||
break
|
break
|
||||||
# torch_gc()
|
# torch_gc()
|
||||||
|
@ -11,6 +11,19 @@ import GPUtil
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokens_path(model_path: str):
|
||||||
|
model_path = model_path.lower()
|
||||||
|
default_tokens_path = (
|
||||||
|
f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json"
|
||||||
|
)
|
||||||
|
if "raven" in model_path:
|
||||||
|
return default_tokens_path
|
||||||
|
elif "world" in model_path:
|
||||||
|
return "rwkv_vocab_v20230424"
|
||||||
|
else:
|
||||||
|
return default_tokens_path
|
||||||
|
|
||||||
|
|
||||||
class SwitchModelBody(BaseModel):
|
class SwitchModelBody(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
strategy: str
|
strategy: str
|
||||||
@ -36,7 +49,7 @@ def switch_model(body: SwitchModelBody, response: Response):
|
|||||||
RWKV(
|
RWKV(
|
||||||
model=body.model,
|
model=body.model,
|
||||||
strategy=body.strategy,
|
strategy=body.strategy,
|
||||||
tokens_path=f"{pathlib.Path(__file__).parent.parent.resolve()}/20B_tokenizer.json",
|
tokens_path=get_tokens_path(body.model),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
106
backend-python/rwkv_pip/rwkv_tokenizer.py
Normal file
106
backend-python/rwkv_pip/rwkv_tokenizer.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class TRIE:
|
||||||
|
__slots__ = tuple("ch,to,values,front".split(","))
|
||||||
|
to: list
|
||||||
|
values: set
|
||||||
|
|
||||||
|
def __init__(self, front=None, ch=None):
|
||||||
|
self.ch = ch
|
||||||
|
self.to = [None for ch in range(256)]
|
||||||
|
self.values = set()
|
||||||
|
self.front = front
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
fr = self
|
||||||
|
ret = []
|
||||||
|
while fr != None:
|
||||||
|
if fr.ch != None:
|
||||||
|
ret.append(fr.ch)
|
||||||
|
fr = fr.front
|
||||||
|
return "<TRIE %s %s>" % (ret[::-1], self.values)
|
||||||
|
|
||||||
|
def add(self, key: bytes, idx: int = 0, val=None):
|
||||||
|
if idx == len(key):
|
||||||
|
if val is None:
|
||||||
|
val = key
|
||||||
|
self.values.add(val)
|
||||||
|
return self
|
||||||
|
ch = key[idx]
|
||||||
|
if self.to[ch] is None:
|
||||||
|
self.to[ch] = TRIE(front=self, ch=ch)
|
||||||
|
return self.to[ch].add(key, idx=idx + 1, val=val)
|
||||||
|
|
||||||
|
def find_longest(self, key: bytes, idx: int = 0):
|
||||||
|
u: TRIE = self
|
||||||
|
ch: int = key[idx]
|
||||||
|
|
||||||
|
while u.to[ch] is not None:
|
||||||
|
u = u.to[ch]
|
||||||
|
idx += 1
|
||||||
|
if u.values:
|
||||||
|
ret = idx, u, u.values
|
||||||
|
if idx == len(key):
|
||||||
|
break
|
||||||
|
ch = key[idx]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class TRIE_TOKENIZER:
|
||||||
|
def __init__(self, file_name):
|
||||||
|
self.idx2token = {}
|
||||||
|
sorted = [] # must be already sorted
|
||||||
|
with open(file_name, "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for l in lines:
|
||||||
|
idx = int(l[: l.index(" ")])
|
||||||
|
x = eval(l[l.index(" ") : l.rindex(" ")])
|
||||||
|
x = x.encode("utf-8") if isinstance(x, str) else x
|
||||||
|
assert isinstance(x, bytes)
|
||||||
|
assert len(x) == int(l[l.rindex(" ") :])
|
||||||
|
sorted += [x]
|
||||||
|
self.idx2token[idx] = x
|
||||||
|
|
||||||
|
self.token2idx = {}
|
||||||
|
for k, v in self.idx2token.items():
|
||||||
|
self.token2idx[v] = int(k)
|
||||||
|
|
||||||
|
self.root = TRIE()
|
||||||
|
for t, i in self.token2idx.items():
|
||||||
|
_ = self.root.add(t, val=(t, i))
|
||||||
|
|
||||||
|
def encodeBytes(self, src: bytes) -> list[int]:
|
||||||
|
idx: int = 0
|
||||||
|
tokens: list[int] = []
|
||||||
|
while idx < len(src):
|
||||||
|
_idx: int = idx
|
||||||
|
idx, _, values = self.root.find_longest(src, idx)
|
||||||
|
assert idx != _idx
|
||||||
|
_, token = next(iter(values))
|
||||||
|
tokens.append(token)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def decodeBytes(self, tokens):
|
||||||
|
return b"".join(map(lambda i: self.idx2token[i], tokens))
|
||||||
|
|
||||||
|
def encode(self, src):
|
||||||
|
return self.encodeBytes(src.encode("utf-8"))
|
||||||
|
|
||||||
|
def decode(self, tokens):
|
||||||
|
try:
|
||||||
|
return self.decodeBytes(tokens).decode("utf-8")
|
||||||
|
except:
|
||||||
|
return "\ufffd" # bad utf-8
|
||||||
|
|
||||||
|
def printTokens(self, tokens):
|
||||||
|
for i in tokens:
|
||||||
|
s = self.idx2token[i]
|
||||||
|
try:
|
||||||
|
s = s.decode("utf-8")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
print(f"{repr(s)}{i}", end=" ")
|
||||||
|
print()
|
65529
backend-python/rwkv_pip/rwkv_vocab_v20230424.txt
Normal file
65529
backend-python/rwkv_pip/rwkv_vocab_v20230424.txt
Normal file
File diff suppressed because it is too large
Load Diff
142
backend-python/rwkv_pip/utils.py
Normal file
142
backend-python/rwkv_pip/utils.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
########################################################################################################
|
||||||
|
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
|
||||||
|
########################################################################################################
|
||||||
|
|
||||||
|
import os, sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class PIPELINE_ARGS:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
temperature=1.0,
|
||||||
|
top_p=0.85,
|
||||||
|
top_k=0,
|
||||||
|
alpha_frequency=0.2,
|
||||||
|
alpha_presence=0.2,
|
||||||
|
token_ban=[],
|
||||||
|
token_stop=[],
|
||||||
|
chunk_len=256,
|
||||||
|
):
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.top_k = top_k
|
||||||
|
self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
|
||||||
|
self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
|
||||||
|
self.token_ban = token_ban # ban the generation of some tokens
|
||||||
|
self.token_stop = token_stop # stop generation whenever you see any token here
|
||||||
|
self.chunk_len = (
|
||||||
|
chunk_len # split input into chunks to save VRAM (shorter -> slower)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PIPELINE:
|
||||||
|
def __init__(self, model, WORD_NAME):
|
||||||
|
self.model = model
|
||||||
|
if WORD_NAME == "cl100k_base":
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
self.tokenizer = tiktoken.get_encoding(WORD_NAME)
|
||||||
|
elif WORD_NAME == "rwkv_vocab_v20230424":
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
from rwkv_tokenizer import TRIE_TOKENIZER
|
||||||
|
|
||||||
|
self.tokenizer = TRIE_TOKENIZER(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from tokenizers import Tokenizer
|
||||||
|
|
||||||
|
self.tokenizer = Tokenizer.from_file(WORD_NAME)
|
||||||
|
|
||||||
|
def refine_context(self, context):
|
||||||
|
context = context.strip().split("\n")
|
||||||
|
for c in range(len(context)):
|
||||||
|
context[c] = context[c].strip().strip("\u3000").strip("\r")
|
||||||
|
context = list(filter(lambda c: c != "", context))
|
||||||
|
context = "\n" + ("\n".join(context)).strip()
|
||||||
|
if context == "":
|
||||||
|
context = "\n"
|
||||||
|
return context
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
if "Tokenizer" in str(type(self.tokenizer)):
|
||||||
|
return self.tokenizer.encode(x).ids
|
||||||
|
else:
|
||||||
|
return self.tokenizer.encode(x)
|
||||||
|
|
||||||
|
def decode(self, x):
|
||||||
|
return self.tokenizer.decode(x)
|
||||||
|
|
||||||
|
def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
|
||||||
|
probs = F.softmax(logits.float(), dim=-1)
|
||||||
|
top_k = int(top_k)
|
||||||
|
if probs.device == torch.device("cpu"):
|
||||||
|
probs = probs.numpy()
|
||||||
|
sorted_ids = np.argsort(probs)
|
||||||
|
sorted_probs = probs[sorted_ids][::-1]
|
||||||
|
cumulative_probs = np.cumsum(sorted_probs)
|
||||||
|
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||||
|
probs[probs < cutoff] = 0
|
||||||
|
if top_k < len(probs) and top_k > 0:
|
||||||
|
probs[sorted_ids[:-top_k]] = 0
|
||||||
|
if temperature != 1.0:
|
||||||
|
probs = probs ** (1.0 / temperature)
|
||||||
|
probs = probs / np.sum(probs)
|
||||||
|
out = np.random.choice(a=len(probs), p=probs)
|
||||||
|
return int(out)
|
||||||
|
else:
|
||||||
|
sorted_ids = torch.argsort(probs)
|
||||||
|
sorted_probs = probs[sorted_ids]
|
||||||
|
sorted_probs = torch.flip(sorted_probs, dims=(0,))
|
||||||
|
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
|
||||||
|
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
||||||
|
probs[probs < cutoff] = 0
|
||||||
|
if top_k < len(probs) and top_k > 0:
|
||||||
|
probs[sorted_ids[:-top_k]] = 0
|
||||||
|
if temperature != 1.0:
|
||||||
|
probs = probs ** (1.0 / temperature)
|
||||||
|
out = torch.multinomial(probs, num_samples=1)[0]
|
||||||
|
return int(out)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None
|
||||||
|
):
|
||||||
|
all_tokens = []
|
||||||
|
out_last = 0
|
||||||
|
out_str = ""
|
||||||
|
occurrence = {}
|
||||||
|
for i in range(token_count):
|
||||||
|
# forward & adjust prob.
|
||||||
|
tokens = self.encode(ctx) if i == 0 else [token]
|
||||||
|
while len(tokens) > 0:
|
||||||
|
out, state = self.model.forward(tokens[: args.chunk_len], state)
|
||||||
|
tokens = tokens[args.chunk_len :]
|
||||||
|
|
||||||
|
for n in args.token_ban:
|
||||||
|
out[n] = -float("inf")
|
||||||
|
for n in occurrence:
|
||||||
|
out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
|
||||||
|
|
||||||
|
# sampler
|
||||||
|
token = self.sample_logits(
|
||||||
|
out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k
|
||||||
|
)
|
||||||
|
if token in args.token_stop:
|
||||||
|
break
|
||||||
|
all_tokens += [token]
|
||||||
|
if token not in occurrence:
|
||||||
|
occurrence[token] = 1
|
||||||
|
else:
|
||||||
|
occurrence[token] += 1
|
||||||
|
|
||||||
|
# output
|
||||||
|
tmp = self.decode(all_tokens[out_last:])
|
||||||
|
if "\ufffd" not in tmp: # is valid utf-8 string?
|
||||||
|
if callback:
|
||||||
|
callback(tmp)
|
||||||
|
out_str += tmp
|
||||||
|
out_last = i + 1
|
||||||
|
return out_str
|
@ -1,8 +1,103 @@
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import Dict
|
from typing import Dict, List
|
||||||
from langchain.llms import RWKV
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from rwkv_pip.utils import PIPELINE
|
||||||
|
|
||||||
|
|
||||||
|
END_OF_TEXT = 0
|
||||||
|
END_OF_LINE = 187
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
||||||
|
|
||||||
|
|
||||||
|
class RWKV:
|
||||||
|
def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
|
||||||
|
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
|
||||||
|
|
||||||
|
self.model = Model(model, strategy)
|
||||||
|
self.pipeline = PIPELINE(self.model, tokens_path)
|
||||||
|
self.model_state = None
|
||||||
|
self.model_tokens = []
|
||||||
|
|
||||||
|
self.CHUNK_LEN = 256
|
||||||
|
|
||||||
|
self.max_tokens_per_generation = 500
|
||||||
|
self.temperature = 1
|
||||||
|
self.top_p = 0.5
|
||||||
|
self.penalty_alpha_presence = 0.4
|
||||||
|
self.penalty_alpha_frequency = 0.4
|
||||||
|
|
||||||
|
self.interface = ":"
|
||||||
|
if "rwkv_vocab" in tokens_path:
|
||||||
|
self.user = "Human"
|
||||||
|
self.bot = "Bot"
|
||||||
|
else:
|
||||||
|
self.user = "Bob"
|
||||||
|
self.bot = "Alice"
|
||||||
|
|
||||||
|
self.AVOID_REPEAT_TOKENS = []
|
||||||
|
AVOID_REPEAT = ",:?!"
|
||||||
|
for i in AVOID_REPEAT:
|
||||||
|
dd = self.pipeline.encode(i)
|
||||||
|
assert len(dd) == 1
|
||||||
|
self.AVOID_REPEAT_TOKENS += dd
|
||||||
|
|
||||||
|
def run_rnn(self, _tokens: List[str], newline_adj: int = 0):
|
||||||
|
tokens = [int(x) for x in _tokens]
|
||||||
|
self.model_tokens += tokens
|
||||||
|
|
||||||
|
while len(tokens) > 0:
|
||||||
|
out, self.model_state = self.model.forward(
|
||||||
|
tokens[: self.CHUNK_LEN], self.model_state
|
||||||
|
)
|
||||||
|
tokens = tokens[self.CHUNK_LEN :]
|
||||||
|
|
||||||
|
out[END_OF_LINE] += newline_adj # adjust \n probability
|
||||||
|
|
||||||
|
if self.model_tokens[-1] in self.AVOID_REPEAT_TOKENS:
|
||||||
|
out[self.model_tokens[-1]] = -999999999
|
||||||
|
return out
|
||||||
|
|
||||||
|
def generate(self, prompt: str, stop: str = None):
|
||||||
|
self.model_state = None
|
||||||
|
self.model_tokens = []
|
||||||
|
logits = self.run_rnn(self.pipeline.encode(prompt))
|
||||||
|
begin = len(self.model_tokens)
|
||||||
|
out_last = begin
|
||||||
|
|
||||||
|
occurrence: Dict = {}
|
||||||
|
|
||||||
|
response = ""
|
||||||
|
for i in range(self.max_tokens_per_generation):
|
||||||
|
for n in occurrence:
|
||||||
|
logits[n] -= (
|
||||||
|
self.penalty_alpha_presence
|
||||||
|
+ occurrence[n] * self.penalty_alpha_frequency
|
||||||
|
)
|
||||||
|
token = self.pipeline.sample_logits(
|
||||||
|
logits, temperature=self.temperature, top_p=self.top_p
|
||||||
|
)
|
||||||
|
|
||||||
|
if token == END_OF_TEXT:
|
||||||
|
break
|
||||||
|
if token not in occurrence:
|
||||||
|
occurrence[token] = 1
|
||||||
|
else:
|
||||||
|
occurrence[token] += 1
|
||||||
|
|
||||||
|
logits = self.run_rnn([token])
|
||||||
|
delta: str = self.pipeline.decode(self.model_tokens[out_last:])
|
||||||
|
if "\ufffd" not in delta: # avoid utf-8 display issues
|
||||||
|
response += delta
|
||||||
|
if stop is not None:
|
||||||
|
if stop in response:
|
||||||
|
response = response.split(stop)[0]
|
||||||
|
yield response, ""
|
||||||
|
break
|
||||||
|
out_last = begin + i + 1
|
||||||
|
yield response, delta
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBody(BaseModel):
|
class ModelConfigBody(BaseModel):
|
||||||
@ -34,47 +129,3 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody:
|
|||||||
presence_penalty=model.penalty_alpha_presence,
|
presence_penalty=model.penalty_alpha_presence,
|
||||||
frequency_penalty=model.penalty_alpha_frequency,
|
frequency_penalty=model.penalty_alpha_frequency,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
|
||||||
|
|
||||||
|
|
||||||
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
|
|
||||||
model.model_state = None
|
|
||||||
model.model_tokens = []
|
|
||||||
logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
|
|
||||||
begin = len(model.model_tokens)
|
|
||||||
out_last = begin
|
|
||||||
|
|
||||||
occurrence: Dict = {}
|
|
||||||
|
|
||||||
response = ""
|
|
||||||
for i in range(model.max_tokens_per_generation):
|
|
||||||
for n in occurrence:
|
|
||||||
logits[n] -= (
|
|
||||||
model.penalty_alpha_presence
|
|
||||||
+ occurrence[n] * model.penalty_alpha_frequency
|
|
||||||
)
|
|
||||||
token = model.pipeline.sample_logits(
|
|
||||||
logits, temperature=model.temperature, top_p=model.top_p
|
|
||||||
)
|
|
||||||
|
|
||||||
END_OF_TEXT = 0
|
|
||||||
if token == END_OF_TEXT:
|
|
||||||
break
|
|
||||||
if token not in occurrence:
|
|
||||||
occurrence[token] = 1
|
|
||||||
else:
|
|
||||||
occurrence[token] += 1
|
|
||||||
|
|
||||||
logits = model.run_rnn([token])
|
|
||||||
delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
|
|
||||||
if "\ufffd" not in delta: # avoid utf-8 display issues
|
|
||||||
response += delta
|
|
||||||
if stop is not None:
|
|
||||||
if stop in response:
|
|
||||||
response = response.split(stop)[0]
|
|
||||||
yield response, ""
|
|
||||||
break
|
|
||||||
yield response, delta
|
|
||||||
out_last = begin + i + 1
|
|
||||||
|
@ -54,6 +54,14 @@
|
|||||||
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/utils/torch.py",
|
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/utils/torch.py",
|
||||||
"path": "backend-python/utils/torch.py"
|
"path": "backend-python/utils/torch.py"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/rwkv_tokenizer.py",
|
||||||
|
"path": "backend-python/rwkv_pip/rwkv_tokenizer.py"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/utils.py",
|
||||||
|
"path": "backend-python/rwkv_pip/utils.py"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd",
|
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd",
|
||||||
"path": "backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd"
|
"path": "backend-python/wkv_cuda_utils/wkv_cuda10_30.pyd"
|
||||||
@ -67,8 +75,12 @@
|
|||||||
"path": "backend-python/wkv_cuda_utils/wkv_cuda_model.py"
|
"path": "backend-python/wkv_cuda_utils/wkv_cuda_model.py"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/20B_tokenizer.json",
|
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/rwkv_vocab_v20230424.txt",
|
||||||
"path": "backend-python/20B_tokenizer.json"
|
"path": "backend-python/rwkv_pip/rwkv_vocab_v20230424.txt"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"url": "https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/backend-python/rwkv_pip/20B_tokenizer.json",
|
||||||
|
"path": "backend-python/rwkv_pip/20B_tokenizer.json"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"url": "https://cdn.jsdelivr.net/gh/pypa/get-pip/public/get-pip.py",
|
"url": "https://cdn.jsdelivr.net/gh/pypa/get-pip/public/get-pip.py",
|
||||||
|
Loading…
Reference in New Issue
Block a user