This commit is contained in:
josc146 2023-07-31 22:46:13 +08:00
parent d12a173f39
commit 8764c37b03

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from enum import Enum, auto
import os
import pathlib
import copy
@ -18,6 +19,12 @@ END_OF_LINE_DOUBLE = 535
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
class RWKVType(Enum):
Raven = auto()
World = auto()
Music = auto()
class AbstractRWKV(ABC):
def __init__(self, model: str, strategy: str, tokens_path: str):
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
@ -29,6 +36,7 @@ class AbstractRWKV(ABC):
self.pipeline = PIPELINE(self.model, tokens_path)
self.model_state = None
self.model_tokens = []
self.rwkv_type: RWKVType = None
self.max_tokens_per_generation = 500
self.temperature = 1
@ -343,10 +351,12 @@ class TextRWKV(AbstractRWKV):
self.interface = ":"
if "world" in self.name.lower():
self.rwkv_type = RWKVType.World
self.user = "Question"
self.bot = "Answer"
self.END_OF_LINE = 11
else:
self.rwkv_type = RWKVType.Raven
self.user = "Bob"
self.bot = "Alice"
self.END_OF_LINE = 187
@ -387,7 +397,7 @@ class TextRWKV(AbstractRWKV):
# Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end
def fix_tokens(self, tokens) -> List[int]:
if "world" in self.name.lower():
if self.rwkv_type == RWKVType.World:
return tokens
if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
@ -427,7 +437,7 @@ The following is a coherent verbose detailed conversation between a girl named {
{bot} likes to tell {user} a lot about herself and her opinions. \
{bot} usually gives {user} kind, helpful and informative advices.\n
"""
if self.user == "Bob"
if self.rwkv_type == RWKVType.Raven
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
)
@ -454,6 +464,8 @@ class MusicRWKV(AbstractRWKV):
self.top_p = 0.8
self.top_k = 8
self.rwkv_type = RWKVType.Music
def adjust_occurrence(self, occurrence: Dict, token: int):
for n in occurrence:
occurrence[n] *= 0.997 #### decay repetition penalty