diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index f492ec2..23bb8db 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -34,6 +34,25 @@ class PIPELINE_ARGS: ) +class ABC_TOKENIZER: + def __init__(self): + self.pad_token_id = 0 + self.bos_token_id = 2 + self.eos_token_id = 3 + + def encode(self, text): + ids = [ord(c) for c in text] + return ids + + def decode(self, ids): + txt = "".join( + chr(idx) if idx > self.eos_token_id else "" + for idx in ids + if idx != self.eos_token_id + ) + return txt + + class PIPELINE: def __init__(self, model, WORD_NAME: str): self.model = model @@ -48,6 +67,8 @@ class PIPELINE: self.tokenizer = TRIE_TOKENIZER( os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt" ) + elif WORD_NAME == "abc_tokenizer": + self.tokenizer = ABC_TOKENIZER() else: if WORD_NAME.endswith(".txt"): sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 2317a20..df58448 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -11,11 +11,6 @@ from pydantic import BaseModel, Field from routes import state_cache import global_var - -END_OF_TEXT = 0 -END_OF_LINE_DOUBLE = 535 - - os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}" @@ -28,6 +23,8 @@ class RWKVType(Enum): class AbstractRWKV(ABC): def __init__(self, model, pipeline): + self.EOS_ID = 0 + self.name = "rwkv" self.model = model self.pipeline = pipeline @@ -274,7 +271,7 @@ class AbstractRWKV(ABC): logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k ) - if token == END_OF_TEXT: + if token == self.EOS_ID: yield response, "", prompt_token_len, completion_token_len break @@ -401,7 +398,7 @@ class TextRWKV(AbstractRWKV): def fix_tokens(self, tokens) -> List[int]: if self.rwkv_type == RWKVType.World: return tokens - if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE: + if len(tokens) > 0 and tokens[-1] == 535: tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE] return tokens @@ -459,7 +456,7 @@ The following is a coherent verbose detailed conversation between a girl named { pass -class MusicRWKV(AbstractRWKV): +class MusicMidiRWKV(AbstractRWKV): def __init__(self, model, pipeline): super().__init__(model, pipeline) @@ -501,8 +498,45 @@ class MusicRWKV(AbstractRWKV): return " " + delta +class MusicAbcRWKV(AbstractRWKV): + def __init__(self, model, pipeline): + super().__init__(model, pipeline) + + self.EOS_ID = 3 + + self.max_tokens_per_generation = 500 + self.temperature = 1 + self.top_p = 0.8 + self.top_k = 8 + + self.rwkv_type = RWKVType.Music + + def adjust_occurrence(self, occurrence: Dict, token: int): + pass + + def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int): + pass + + def fix_tokens(self, tokens) -> List[int]: + return tokens + + def run_rnn( + self, _tokens: List[str], newline_adj: int = 0 + ) -> Tuple[List[float], int]: + tokens = [int(x) for x in _tokens] + token_len = len(tokens) + self.model_tokens += tokens + out, self.model_state = self.model.forward(tokens, self.model_state) + return out, token_len + + def delta_postprocess(self, delta: str) -> str: + return delta + + def get_tokenizer(tokenizer_len: int): tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/" + if tokenizer_len < 20096: + return "abc_tokenizer" if tokenizer_len < 50277: return tokenizer_dir + "tokenizer-midi.json" elif tokenizer_len < 65536: @@ -550,7 +584,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV rwkv_map: dict[str, Type[AbstractRWKV]] = { "20B_tokenizer": TextRWKV, "rwkv_vocab_v20230424": TextRWKV, - "tokenizer-midi": MusicRWKV, + "tokenizer-midi": MusicMidiRWKV, + "abc_tokenizer": MusicAbcRWKV, } tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0] rwkv: AbstractRWKV