From 8764c37b031ab791504765d52e0f58b34f175281 Mon Sep 17 00:00:00 2001 From: josc146 Date: Mon, 31 Jul 2023 22:46:13 +0800 Subject: [PATCH] RWKVType --- backend-python/utils/rwkv.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 21ea4b4..b24e4e1 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from enum import Enum, auto import os import pathlib import copy @@ -18,6 +19,12 @@ END_OF_LINE_DOUBLE = 535 os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}" +class RWKVType(Enum): + Raven = auto() + World = auto() + Music = auto() + + class AbstractRWKV(ABC): def __init__(self, model: str, strategy: str, tokens_path: str): from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work @@ -29,6 +36,7 @@ class AbstractRWKV(ABC): self.pipeline = PIPELINE(self.model, tokens_path) self.model_state = None self.model_tokens = [] + self.rwkv_type: RWKVType = None self.max_tokens_per_generation = 500 self.temperature = 1 @@ -343,10 +351,12 @@ class TextRWKV(AbstractRWKV): self.interface = ":" if "world" in self.name.lower(): + self.rwkv_type = RWKVType.World self.user = "Question" self.bot = "Answer" self.END_OF_LINE = 11 else: + self.rwkv_type = RWKVType.Raven self.user = "Bob" self.bot = "Alice" self.END_OF_LINE = 187 @@ -387,7 +397,7 @@ class TextRWKV(AbstractRWKV): # Model only saw '\n\n' as [187, 187] before, but the tokenizer outputs [535] for it at the end def fix_tokens(self, tokens) -> List[int]: - if "world" in self.name.lower(): + if self.rwkv_type == RWKVType.World: return tokens if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE: tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE] @@ -427,7 +437,7 @@ The following is a coherent verbose detailed conversation between a girl named { {bot} likes to tell {user} a lot about herself and her opinions. \ {bot} usually gives {user} kind, helpful and informative advices.\n """ - if self.user == "Bob" + if self.rwkv_type == RWKVType.Raven else f"{user}{interface} hi\n\n{bot}{interface} Hi. " + "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n" ) @@ -454,6 +464,8 @@ class MusicRWKV(AbstractRWKV): self.top_p = 0.8 self.top_k = 8 + self.rwkv_type = RWKVType.Music + def adjust_occurrence(self, occurrence: Dict, token: int): for n in occurrence: occurrence[n] *= 0.997 #### decay repetition penalty