abc music inference support
This commit is contained in:
parent
ff330a5487
commit
16079a3cba
21
backend-python/rwkv_pip/utils.py
vendored
21
backend-python/rwkv_pip/utils.py
vendored
@ -34,6 +34,25 @@ class PIPELINE_ARGS:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ABC_TOKENIZER:
|
||||||
|
def __init__(self):
|
||||||
|
self.pad_token_id = 0
|
||||||
|
self.bos_token_id = 2
|
||||||
|
self.eos_token_id = 3
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
ids = [ord(c) for c in text]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def decode(self, ids):
|
||||||
|
txt = "".join(
|
||||||
|
chr(idx) if idx > self.eos_token_id else ""
|
||||||
|
for idx in ids
|
||||||
|
if idx != self.eos_token_id
|
||||||
|
)
|
||||||
|
return txt
|
||||||
|
|
||||||
|
|
||||||
class PIPELINE:
|
class PIPELINE:
|
||||||
def __init__(self, model, WORD_NAME: str):
|
def __init__(self, model, WORD_NAME: str):
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -48,6 +67,8 @@ class PIPELINE:
|
|||||||
self.tokenizer = TRIE_TOKENIZER(
|
self.tokenizer = TRIE_TOKENIZER(
|
||||||
os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
|
os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
|
||||||
)
|
)
|
||||||
|
elif WORD_NAME == "abc_tokenizer":
|
||||||
|
self.tokenizer = ABC_TOKENIZER()
|
||||||
else:
|
else:
|
||||||
if WORD_NAME.endswith(".txt"):
|
if WORD_NAME.endswith(".txt"):
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
@ -11,11 +11,6 @@ from pydantic import BaseModel, Field
|
|||||||
from routes import state_cache
|
from routes import state_cache
|
||||||
import global_var
|
import global_var
|
||||||
|
|
||||||
|
|
||||||
END_OF_TEXT = 0
|
|
||||||
END_OF_LINE_DOUBLE = 535
|
|
||||||
|
|
||||||
|
|
||||||
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
||||||
|
|
||||||
|
|
||||||
@ -28,6 +23,8 @@ class RWKVType(Enum):
|
|||||||
|
|
||||||
class AbstractRWKV(ABC):
|
class AbstractRWKV(ABC):
|
||||||
def __init__(self, model, pipeline):
|
def __init__(self, model, pipeline):
|
||||||
|
self.EOS_ID = 0
|
||||||
|
|
||||||
self.name = "rwkv"
|
self.name = "rwkv"
|
||||||
self.model = model
|
self.model = model
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
@ -274,7 +271,7 @@ class AbstractRWKV(ABC):
|
|||||||
logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
|
logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
if token == END_OF_TEXT:
|
if token == self.EOS_ID:
|
||||||
yield response, "", prompt_token_len, completion_token_len
|
yield response, "", prompt_token_len, completion_token_len
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -401,7 +398,7 @@ class TextRWKV(AbstractRWKV):
|
|||||||
def fix_tokens(self, tokens) -> List[int]:
|
def fix_tokens(self, tokens) -> List[int]:
|
||||||
if self.rwkv_type == RWKVType.World:
|
if self.rwkv_type == RWKVType.World:
|
||||||
return tokens
|
return tokens
|
||||||
if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
|
if len(tokens) > 0 and tokens[-1] == 535:
|
||||||
tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
|
tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
@ -459,7 +456,7 @@ The following is a coherent verbose detailed conversation between a girl named {
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MusicRWKV(AbstractRWKV):
|
class MusicMidiRWKV(AbstractRWKV):
|
||||||
def __init__(self, model, pipeline):
|
def __init__(self, model, pipeline):
|
||||||
super().__init__(model, pipeline)
|
super().__init__(model, pipeline)
|
||||||
|
|
||||||
@ -501,8 +498,45 @@ class MusicRWKV(AbstractRWKV):
|
|||||||
return " " + delta
|
return " " + delta
|
||||||
|
|
||||||
|
|
||||||
|
class MusicAbcRWKV(AbstractRWKV):
|
||||||
|
def __init__(self, model, pipeline):
|
||||||
|
super().__init__(model, pipeline)
|
||||||
|
|
||||||
|
self.EOS_ID = 3
|
||||||
|
|
||||||
|
self.max_tokens_per_generation = 500
|
||||||
|
self.temperature = 1
|
||||||
|
self.top_p = 0.8
|
||||||
|
self.top_k = 8
|
||||||
|
|
||||||
|
self.rwkv_type = RWKVType.Music
|
||||||
|
|
||||||
|
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def fix_tokens(self, tokens) -> List[int]:
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def run_rnn(
|
||||||
|
self, _tokens: List[str], newline_adj: int = 0
|
||||||
|
) -> Tuple[List[float], int]:
|
||||||
|
tokens = [int(x) for x in _tokens]
|
||||||
|
token_len = len(tokens)
|
||||||
|
self.model_tokens += tokens
|
||||||
|
out, self.model_state = self.model.forward(tokens, self.model_state)
|
||||||
|
return out, token_len
|
||||||
|
|
||||||
|
def delta_postprocess(self, delta: str) -> str:
|
||||||
|
return delta
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(tokenizer_len: int):
|
def get_tokenizer(tokenizer_len: int):
|
||||||
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
|
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
|
||||||
|
if tokenizer_len < 20096:
|
||||||
|
return "abc_tokenizer"
|
||||||
if tokenizer_len < 50277:
|
if tokenizer_len < 50277:
|
||||||
return tokenizer_dir + "tokenizer-midi.json"
|
return tokenizer_dir + "tokenizer-midi.json"
|
||||||
elif tokenizer_len < 65536:
|
elif tokenizer_len < 65536:
|
||||||
@ -550,7 +584,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
|
|||||||
rwkv_map: dict[str, Type[AbstractRWKV]] = {
|
rwkv_map: dict[str, Type[AbstractRWKV]] = {
|
||||||
"20B_tokenizer": TextRWKV,
|
"20B_tokenizer": TextRWKV,
|
||||||
"rwkv_vocab_v20230424": TextRWKV,
|
"rwkv_vocab_v20230424": TextRWKV,
|
||||||
"tokenizer-midi": MusicRWKV,
|
"tokenizer-midi": MusicMidiRWKV,
|
||||||
|
"abc_tokenizer": MusicAbcRWKV,
|
||||||
}
|
}
|
||||||
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
|
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
|
||||||
rwkv: AbstractRWKV
|
rwkv: AbstractRWKV
|
||||||
|
Loading…
x
Reference in New Issue
Block a user