abc music inference support
parent ff330a5487
commit 16079a3cba
backend-python/rwkv_pip/utils.py (vendored), 21 changes
@@ -34,6 +34,25 @@ class PIPELINE_ARGS:
         )
 
 
+class ABC_TOKENIZER:
+    def __init__(self):
+        self.pad_token_id = 0
+        self.bos_token_id = 2
+        self.eos_token_id = 3
+
+    def encode(self, text):
+        ids = [ord(c) for c in text]
+        return ids
+
+    def decode(self, ids):
+        txt = "".join(
+            chr(idx) if idx > self.eos_token_id else ""
+            for idx in ids
+            if idx != self.eos_token_id
+        )
+        return txt
+
+
 class PIPELINE:
     def __init__(self, model, WORD_NAME: str):
         self.model = model
@@ -48,6 +67,8 @@ class PIPELINE:
             self.tokenizer = TRIE_TOKENIZER(
                 os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
            )
+        elif WORD_NAME == "abc_tokenizer":
+            self.tokenizer = ABC_TOKENIZER()
         else:
             if WORD_NAME.endswith(".txt"):
                 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
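For reference, a quick round trip through the new tokenizer. The ABC_TOKENIZER body is restated verbatim from the hunk above so the snippet runs standalone; the sample ABC string is illustrative only and not part of the commit.

class ABC_TOKENIZER:
    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 2
        self.eos_token_id = 3

    def encode(self, text):
        ids = [ord(c) for c in text]
        return ids

    def decode(self, ids):
        txt = "".join(
            chr(idx) if idx > self.eos_token_id else ""
            for idx in ids
            if idx != self.eos_token_id
        )
        return txt


tok = ABC_TOKENIZER()
ids = tok.encode("X:1\nK:C\nCDEF|")  # plain Unicode code points: [88, 58, 49, 10, 75, ...]
assert tok.decode(ids + [tok.eos_token_id]) == "X:1\nK:C\nCDEF|"  # eos (3) and lower ids are dropped on decode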
@@ -11,11 +11,6 @@ from pydantic import BaseModel, Field
 from routes import state_cache
 import global_var
 
 
-END_OF_TEXT = 0
-END_OF_LINE_DOUBLE = 535
-
-
 os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
 
-
@@ -28,6 +23,8 @@ class RWKVType(Enum):
 
 class AbstractRWKV(ABC):
     def __init__(self, model, pipeline):
+        self.EOS_ID = 0
+
         self.name = "rwkv"
         self.model = model
         self.pipeline = pipeline
@@ -274,7 +271,7 @@ class AbstractRWKV(ABC):
                 logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
             )
 
-            if token == END_OF_TEXT:
+            if token == self.EOS_ID:
                 yield response, "", prompt_token_len, completion_token_len
                 break
 
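The switch from the module-level END_OF_TEXT constant to self.EOS_ID lets each subclass stop on its own end-of-sequence token: 0 by default, 3 for the new ABC model added below. A tiny illustration of that check, with a made-up stops_at helper standing in for the generation loop (not code from this commit):

DEFAULT_EOS_ID = 0  # set in AbstractRWKV.__init__
ABC_EOS_ID = 3      # MusicAbcRWKV override, matching ABC_TOKENIZER.eos_token_id


def stops_at(sampled_tokens, eos_id):
    # Mirrors the `if token == self.EOS_ID: break` check in the loop above.
    kept = []
    for token in sampled_tokens:
        if token == eos_id:
            break
        kept.append(token)
    return kept


print(stops_at([88, 58, 0, 49], DEFAULT_EOS_ID))  # [88, 58] - text/MIDI models stop at 0
print(stops_at([88, 58, 0, 49], ABC_EOS_ID))      # [88, 58, 0, 49] - the ABC model ignores 0
print(stops_at([88, 58, 3, 49], ABC_EOS_ID))      # [88, 58] - the ABC model stops at 3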
@@ -401,7 +398,7 @@ class TextRWKV(AbstractRWKV):
     def fix_tokens(self, tokens) -> List[int]:
         if self.rwkv_type == RWKVType.World:
             return tokens
-        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
+        if len(tokens) > 0 and tokens[-1] == 535:
             tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
         return tokens
 
@@ -459,7 +456,7 @@ The following is a coherent verbose detailed conversation between a girl named {
         pass
 
 
-class MusicRWKV(AbstractRWKV):
+class MusicMidiRWKV(AbstractRWKV):
     def __init__(self, model, pipeline):
         super().__init__(model, pipeline)
 
@@ -501,8 +498,45 @@ class MusicRWKV(AbstractRWKV):
         return " " + delta
 
 
+class MusicAbcRWKV(AbstractRWKV):
+    def __init__(self, model, pipeline):
+        super().__init__(model, pipeline)
+
+        self.EOS_ID = 3
+
+        self.max_tokens_per_generation = 500
+        self.temperature = 1
+        self.top_p = 0.8
+        self.top_k = 8
+
+        self.rwkv_type = RWKVType.Music
+
+    def adjust_occurrence(self, occurrence: Dict, token: int):
+        pass
+
+    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
+        pass
+
+    def fix_tokens(self, tokens) -> List[int]:
+        return tokens
+
+    def run_rnn(
+        self, _tokens: List[str], newline_adj: int = 0
+    ) -> Tuple[List[float], int]:
+        tokens = [int(x) for x in _tokens]
+        token_len = len(tokens)
+        self.model_tokens += tokens
+        out, self.model_state = self.model.forward(tokens, self.model_state)
+        return out, token_len
+
+    def delta_postprocess(self, delta: str) -> str:
+        return delta
+
+
 def get_tokenizer(tokenizer_len: int):
     tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
+    if tokenizer_len < 20096:
+        return "abc_tokenizer"
     if tokenizer_len < 50277:
         return tokenizer_dir + "tokenizer-midi.json"
     elif tokenizer_len < 65536:
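The ABC branch is selected purely by vocabulary size: anything under 20096 entries is treated as the character-level ABC vocabulary, checked before the existing MIDI (< 50277) and larger text-vocab branches. A standalone mirror of that dispatch, with a made-up pick_tokenizer name and a placeholder for the branches that fall outside the hunk shown above:

def pick_tokenizer(vocab_len: int) -> str:
    # Same threshold order as get_tokenizer above; the >= 50277 branches are
    # not shown in this diff, so a placeholder string stands in for them.
    if vocab_len < 20096:
        return "abc_tokenizer"        # new: character-level ABC vocab
    if vocab_len < 50277:
        return "tokenizer-midi.json"  # existing MIDI vocab
    return "larger text vocab"        # placeholder for the remaining branches


print(pick_tokenizer(128))    # abc_tokenizer
print(pick_tokenizer(20096))  # tokenizer-midi.json
print(pick_tokenizer(60000))  # larger text vocab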
@@ -550,7 +584,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
     rwkv_map: dict[str, Type[AbstractRWKV]] = {
         "20B_tokenizer": TextRWKV,
         "rwkv_vocab_v20230424": TextRWKV,
-        "tokenizer-midi": MusicRWKV,
+        "tokenizer-midi": MusicMidiRWKV,
+        "abc_tokenizer": MusicAbcRWKV,
     }
     tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
     rwkv: AbstractRWKV
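Putting it together, the tokenizer argument is reduced to its basename without extension and looked up in rwkv_map. A small illustration of that lookup; the dict values are stand-in strings here, since the real mapping points at the RWKV classes in this module, and the MIDI path is an invented example:

import os

rwkv_map = {
    "20B_tokenizer": "TextRWKV",
    "rwkv_vocab_v20230424": "TextRWKV",
    "tokenizer-midi": "MusicMidiRWKV",
    "abc_tokenizer": "MusicAbcRWKV",
}

for tokenizer in ["abc_tokenizer", "rwkv_pip/tokenizer-midi.json"]:
    name = os.path.splitext(os.path.basename(tokenizer))[0]
    print(name, "->", rwkv_map[name])
# abc_tokenizer -> MusicAbcRWKV
# tokenizer-midi -> MusicMidiRWKV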