abc music inference support
This commit is contained in:
		
							parent
							
								
									ff330a5487
								
							
						
					
					
						commit
						16079a3cba
					
				
							
								
								
									
										21
									
								
								backend-python/rwkv_pip/utils.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								backend-python/rwkv_pip/utils.py
									
									
									
									
										vendored
									
									
								
							@ -34,6 +34,25 @@ class PIPELINE_ARGS:
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ABC_TOKENIZER:
    """Character-level tokenizer that maps each character to its code point.

    The ids ``0`` (pad), ``2`` (bos) and ``3`` (eos) are exposed as special
    token attributes. ``decode`` treats every id that is not greater than
    ``eos_token_id`` as carrying no text.
    """

    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 2
        self.eos_token_id = 3

    def encode(self, text):
        """Return the list of code points (``ord``) of the characters in *text*."""
        return [ord(c) for c in text]

    def decode(self, ids):
        """Return the text spelled by *ids*, skipping the reserved low ids.

        Ids less than or equal to ``eos_token_id`` (0 through 3, which covers
        pad, bos and eos) contribute nothing to the output; every other id is
        converted with ``chr``.
        """
        # One `>` comparison is enough: it already excludes the eos id along
        # with every other reserved id below it.
        return "".join(chr(idx) for idx in ids if idx > self.eos_token_id)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PIPELINE:
 | 
			
		||||
    def __init__(self, model, WORD_NAME: str):
 | 
			
		||||
        self.model = model
 | 
			
		||||
@ -48,6 +67,8 @@ class PIPELINE:
 | 
			
		||||
            self.tokenizer = TRIE_TOKENIZER(
 | 
			
		||||
                os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
 | 
			
		||||
            )
 | 
			
		||||
        elif WORD_NAME == "abc_tokenizer":
 | 
			
		||||
            self.tokenizer = ABC_TOKENIZER()
 | 
			
		||||
        else:
 | 
			
		||||
            if WORD_NAME.endswith(".txt"):
 | 
			
		||||
                sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 | 
			
		||||
 | 
			
		||||
@ -11,11 +11,6 @@ from pydantic import BaseModel, Field
 | 
			
		||||
from routes import state_cache
 | 
			
		||||
import global_var
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
END_OF_TEXT = 0
 | 
			
		||||
END_OF_LINE_DOUBLE = 535
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -28,6 +23,8 @@ class RWKVType(Enum):
 | 
			
		||||
 | 
			
		||||
class AbstractRWKV(ABC):
 | 
			
		||||
    def __init__(self, model, pipeline):
 | 
			
		||||
        self.EOS_ID = 0
 | 
			
		||||
 | 
			
		||||
        self.name = "rwkv"
 | 
			
		||||
        self.model = model
 | 
			
		||||
        self.pipeline = pipeline
 | 
			
		||||
@ -274,7 +271,7 @@ class AbstractRWKV(ABC):
 | 
			
		||||
                logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if token == END_OF_TEXT:
 | 
			
		||||
            if token == self.EOS_ID:
 | 
			
		||||
                yield response, "", prompt_token_len, completion_token_len
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
@ -401,7 +398,7 @@ class TextRWKV(AbstractRWKV):
 | 
			
		||||
    def fix_tokens(self, tokens) -> List[int]:
 | 
			
		||||
        if self.rwkv_type == RWKVType.World:
 | 
			
		||||
            return tokens
 | 
			
		||||
        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
 | 
			
		||||
        if len(tokens) > 0 and tokens[-1] == 535:
 | 
			
		||||
            tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
 | 
			
		||||
        return tokens
 | 
			
		||||
 | 
			
		||||
@ -459,7 +456,7 @@ The following is a coherent verbose detailed conversation between a girl named {
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MusicRWKV(AbstractRWKV):
 | 
			
		||||
class MusicMidiRWKV(AbstractRWKV):
 | 
			
		||||
    def __init__(self, model, pipeline):
 | 
			
		||||
        super().__init__(model, pipeline)
 | 
			
		||||
 | 
			
		||||
@ -501,8 +498,45 @@ class MusicRWKV(AbstractRWKV):
 | 
			
		||||
        return " " + delta
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MusicAbcRWKV(AbstractRWKV):
    """RWKV wrapper registered under the ``abc_tokenizer`` tokenizer name.

    Sets its own end-of-sequence id and sampling defaults, and leaves the
    occurrence, logit, token and delta hooks as no-ops / pass-throughs.
    """

    def __init__(self, model, pipeline):
        super().__init__(model, pipeline)

        # Replaces the end-of-sequence id assigned by the base constructor.
        # NOTE(review): presumably matches the ABC tokenizer's eos id — confirm.
        self.EOS_ID = 3

        # Default sampling parameters for this model type.
        self.max_tokens_per_generation = 500
        self.temperature = 1
        self.top_p = 0.8
        self.top_k = 8

        self.rwkv_type = RWKVType.Music

    def adjust_occurrence(self, occurrence: Dict, token: int):
        """Do nothing; *occurrence* is left untouched."""
        return None

    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
        """Do nothing; *logits* are left untouched."""
        return None

    def fix_tokens(self, tokens) -> List[int]:
        """Return *tokens* unchanged."""
        return tokens

    def run_rnn(
        self, _tokens: List[str], newline_adj: int = 0
    ) -> Tuple[List[float], int]:
        """Feed *_tokens* to the model and return ``(output, token_count)``.

        Each entry of *_tokens* is converted with ``int``; the converted ids
        are appended to ``self.model_tokens`` and passed to
        ``self.model.forward`` together with the current ``self.model_state``,
        which is replaced by the state the call returns. *newline_adj* is
        accepted but not used here.
        """
        ids = list(map(int, _tokens))
        count = len(ids)
        self.model_tokens += ids
        logits, self.model_state = self.model.forward(ids, self.model_state)
        return logits, count

    def delta_postprocess(self, delta: str) -> str:
        """Return *delta* unchanged."""
        return delta
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_tokenizer(tokenizer_len: int):
 | 
			
		||||
    tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
 | 
			
		||||
    if tokenizer_len < 20096:
 | 
			
		||||
        return "abc_tokenizer"
 | 
			
		||||
    if tokenizer_len < 50277:
 | 
			
		||||
        return tokenizer_dir + "tokenizer-midi.json"
 | 
			
		||||
    elif tokenizer_len < 65536:
 | 
			
		||||
@ -550,7 +584,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
 | 
			
		||||
    rwkv_map: dict[str, Type[AbstractRWKV]] = {
 | 
			
		||||
        "20B_tokenizer": TextRWKV,
 | 
			
		||||
        "rwkv_vocab_v20230424": TextRWKV,
 | 
			
		||||
        "tokenizer-midi": MusicRWKV,
 | 
			
		||||
        "tokenizer-midi": MusicMidiRWKV,
 | 
			
		||||
        "abc_tokenizer": MusicAbcRWKV,
 | 
			
		||||
    }
 | 
			
		||||
    tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
 | 
			
		||||
    rwkv: AbstractRWKV
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user