This commit is contained in:
josc146 2023-07-31 22:46:13 +08:00
parent d12a173f39
commit 8764c37b03

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum, auto
import os import os
import pathlib import pathlib
import copy import copy
@ -18,6 +19,12 @@ END_OF_LINE_DOUBLE = 535
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}" os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
# Tags the kind of RWKV model an AbstractRWKV instance wraps; stored in `self.rwkv_type`.
# Member values are generated by `auto()`, so code should compare against the members
# themselves (e.g. `self.rwkv_type == RWKVType.World`) rather than against raw values.
class RWKVType(Enum):
# Assigned by TextRWKV when the model name does not contain "world" (user "Bob", bot "Alice").
Raven = auto()
# Assigned by TextRWKV when the model name contains "world" (case-insensitive);
# fix_tokens returns tokens unchanged for this type.
World = auto()
# Assigned by MusicRWKV.
Music = auto()
class AbstractRWKV(ABC): class AbstractRWKV(ABC):
def __init__(self, model: str, strategy: str, tokens_path: str): def __init__(self, model: str, strategy: str, tokens_path: str):
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
@ -29,6 +36,7 @@ class AbstractRWKV(ABC):
self.pipeline = PIPELINE(self.model, tokens_path) self.pipeline = PIPELINE(self.model, tokens_path)
self.model_state = None self.model_state = None
self.model_tokens = [] self.model_tokens = []
self.rwkv_type: RWKVType = None
self.max_tokens_per_generation = 500 self.max_tokens_per_generation = 500
self.temperature = 1 self.temperature = 1
@ -343,10 +351,12 @@ class TextRWKV(AbstractRWKV):
self.interface = ":" self.interface = ":"
if "world" in self.name.lower(): if "world" in self.name.lower():
self.rwkv_type = RWKVType.World
self.user = "Question" self.user = "Question"
self.bot = "Answer" self.bot = "Answer"
self.END_OF_LINE = 11 self.END_OF_LINE = 11
else: else:
self.rwkv_type = RWKVType.Raven
self.user = "Bob" self.user = "Bob"
self.bot = "Alice" self.bot = "Alice"
self.END_OF_LINE = 187 self.END_OF_LINE = 187
@ -387,7 +397,7 @@ class TextRWKV(AbstractRWKV):
# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
def fix_tokens(self, tokens) -> List[int]: def fix_tokens(self, tokens) -> List[int]:
if "world" in self.name.lower(): if self.rwkv_type == RWKVType.World:
return tokens return tokens
if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE: if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE] tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
@ -427,7 +437,7 @@ The following is a coherent verbose detailed conversation between a girl named {
{bot} likes to tell {user} a lot about herself and her opinions. \ {bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n {bot} usually gives {user} kind, helpful and informative advices.\n
""" """
if self.user == "Bob" if self.rwkv_type == RWKVType.Raven
else f"{user}{interface} hi\n\n{bot}{interface} Hi. " else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n" + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
) )
@ -454,6 +464,8 @@ class MusicRWKV(AbstractRWKV):
self.top_p = 0.8 self.top_p = 0.8
self.top_k = 8 self.top_k = 8
self.rwkv_type = RWKVType.Music
def adjust_occurrence(self, occurrence: Dict, token: int): def adjust_occurrence(self, occurrence: Dict, token: int):
for n in occurrence: for n in occurrence:
occurrence[n] *= 0.997 #### decay repetition penalty occurrence[n] *= 0.997 #### decay repetition penalty