abc music inference support
This commit is contained in:
parent ff330a5487
commit 16079a3cba

backend-python/rwkv_pip/utils.py | 21 (vendored)
@@ -34,6 +34,25 @@ class PIPELINE_ARGS:
         )


+class ABC_TOKENIZER:
+    def __init__(self):
+        self.pad_token_id = 0
+        self.bos_token_id = 2
+        self.eos_token_id = 3
+
+    def encode(self, text):
+        ids = [ord(c) for c in text]
+        return ids
+
+    def decode(self, ids):
+        txt = "".join(
+            chr(idx) if idx > self.eos_token_id else ""
+            for idx in ids
+            if idx != self.eos_token_id
+        )
+        return txt
+
+
 class PIPELINE:
     def __init__(self, model, WORD_NAME: str):
         self.model = model
@@ -48,6 +67,8 @@ class PIPELINE:
             self.tokenizer = TRIE_TOKENIZER(
                 os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
             )
+        elif WORD_NAME == "abc_tokenizer":
+            self.tokenizer = ABC_TOKENIZER()
         else:
             if WORD_NAME.endswith(".txt"):
                 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
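The new ABC_TOKENIZER is a plain character-level codec: encode() maps each character to its Unicode code point, and decode() drops the eos id and any id at or below it while turning the remaining ids back into characters. PIPELINE hands out this codec when constructed with WORD_NAME == "abc_tokenizer". A quick round-trip sketch using the class exactly as defined in the hunk above (the ABC fragment is only illustrative input):

tok = ABC_TOKENIZER()

ids = tok.encode("X:1\nK:C\nCDEF GABc|")     # one id per character
print(ids[:4])                               # [88, 58, 49, 10]  ('X', ':', '1', '\n')

# decode() filters out the eos id (3) and ignores ids <= eos_token_id,
# so a trailing end-of-sequence marker never leaks into the text
print(tok.decode(ids + [tok.eos_token_id]))  # prints the original three-line fragment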
@@ -11,11 +11,6 @@ from pydantic import BaseModel, Field
 from routes import state_cache
 import global_var

-END_OF_TEXT = 0
-END_OF_LINE_DOUBLE = 535
-
-
-
 os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"


@@ -28,6 +23,8 @@ class RWKVType(Enum):

 class AbstractRWKV(ABC):
     def __init__(self, model, pipeline):
+        self.EOS_ID = 0
+
         self.name = "rwkv"
         self.model = model
         self.pipeline = pipeline
@@ -274,7 +271,7 @@ class AbstractRWKV(ABC):
                 logits, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k
             )

-            if token == END_OF_TEXT:
+            if token == self.EOS_ID:
                 yield response, "", prompt_token_len, completion_token_len
                 break

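Generation now stops when the sampled token equals the instance's EOS_ID rather than the removed module-level END_OF_TEXT constant, so each subclass can declare its own end-of-sequence id (0 for text and MIDI, 3 for the ABC vocabulary added below). A minimal, self-contained sketch of that stop condition; the class and function names here are placeholders, not the real ones:

class _Base:
    def __init__(self):
        self.EOS_ID = 0  # default set in AbstractRWKV.__init__

class _Abc(_Base):
    def __init__(self):
        super().__init__()
        self.EOS_ID = 3  # overridden the way MusicAbcRWKV does

def take_until_eos(rwkv, sampled):
    out = []
    for token in sampled:
        if token == rwkv.EOS_ID:  # same check as `if token == self.EOS_ID:` above
            break
        out.append(token)
    return out

print(take_until_eos(_Base(), [12, 7, 0, 9]))     # [12, 7]
print(take_until_eos(_Abc(), [12, 7, 0, 3, 9]))   # [12, 7, 0]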
@@ -401,7 +398,7 @@ class TextRWKV(AbstractRWKV):
     def fix_tokens(self, tokens) -> List[int]:
         if self.rwkv_type == RWKVType.World:
             return tokens
-        if len(tokens) > 0 and tokens[-1] == END_OF_LINE_DOUBLE:
+        if len(tokens) > 0 and tokens[-1] == 535:
             tokens = tokens[:-1] + [self.END_OF_LINE, self.END_OF_LINE]
         return tokens

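For the text models, fix_tokens still rewrites a trailing double-newline token into two single newlines; the only change is that the inlined 535 replaces the deleted END_OF_LINE_DOUBLE constant (World models return early and skip the rewrite). A tiny illustration, assuming 187 is the single-newline id stored in self.END_OF_LINE for the 20B tokenizer:

END_OF_LINE = 187     # assumption: "\n" token id used by TextRWKV for the 20B vocab
DOUBLE_NEWLINE = 535  # "\n\n" token id, now written inline in fix_tokens

tokens = [3166, 4839, 535]  # arbitrary prompt ending with the double-newline token
if len(tokens) > 0 and tokens[-1] == DOUBLE_NEWLINE:
    tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE]
print(tokens)  # [3166, 4839, 187, 187]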
@@ -459,7 +456,7 @@ The following is a coherent verbose detailed conversation between a girl named {
             pass


-class MusicRWKV(AbstractRWKV):
+class MusicMidiRWKV(AbstractRWKV):
     def __init__(self, model, pipeline):
         super().__init__(model, pipeline)

@@ -501,8 +498,45 @@ class MusicRWKV(AbstractRWKV):
         return " " + delta


+class MusicAbcRWKV(AbstractRWKV):
+    def __init__(self, model, pipeline):
+        super().__init__(model, pipeline)
+
+        self.EOS_ID = 3
+
+        self.max_tokens_per_generation = 500
+        self.temperature = 1
+        self.top_p = 0.8
+        self.top_k = 8
+
+        self.rwkv_type = RWKVType.Music
+
+    def adjust_occurrence(self, occurrence: Dict, token: int):
+        pass
+
+    def adjust_forward_logits(self, logits: List[float], occurrence: Dict, i: int):
+        pass
+
+    def fix_tokens(self, tokens) -> List[int]:
+        return tokens
+
+    def run_rnn(
+        self, _tokens: List[str], newline_adj: int = 0
+    ) -> Tuple[List[float], int]:
+        tokens = [int(x) for x in _tokens]
+        token_len = len(tokens)
+        self.model_tokens += tokens
+        out, self.model_state = self.model.forward(tokens, self.model_state)
+        return out, token_len
+
+    def delta_postprocess(self, delta: str) -> str:
+        return delta
+
+
 def get_tokenizer(tokenizer_len: int):
     tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
+    if tokenizer_len < 20096:
+        return "abc_tokenizer"
     if tokenizer_len < 50277:
         return tokenizer_dir + "tokenizer-midi.json"
     elif tokenizer_len < 65536:
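get_tokenizer() picks a tokenizer purely from the model's vocabulary size, and the new branch routes anything under 20096 to the "abc_tokenizer" string, which PIPELINE turns into an ABC_TOKENIZER instance. A sketch of the dispatch; the two larger branches are not shown in this hunk, so their return values here are assumptions inferred from the rwkv_map keys below, and file paths are dropped:

def pick_tokenizer(vocab_size: int) -> str:
    # thresholds as in get_tokenizer() above
    if vocab_size < 20096:
        return "abc_tokenizer"        # character-level ABC vocabulary (new branch)
    if vocab_size < 50277:
        return "tokenizer-midi.json"  # MIDI event vocabulary
    if vocab_size < 65536:
        return "20B_tokenizer.json"   # assumed: 20B BPE vocabulary
    return "rwkv_vocab_v20230424"     # assumed: World models use the trie tokenizer

print(pick_tokenizer(1024))   # abc_tokenizer
print(pick_tokenizer(20096))  # tokenizer-midi.json
print(pick_tokenizer(65536))  # rwkv_vocab_v20230424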
@@ -550,7 +584,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
     rwkv_map: dict[str, Type[AbstractRWKV]] = {
         "20B_tokenizer": TextRWKV,
         "rwkv_vocab_v20230424": TextRWKV,
-        "tokenizer-midi": MusicRWKV,
+        "tokenizer-midi": MusicMidiRWKV,
+        "abc_tokenizer": MusicAbcRWKV,
     }
     tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
     rwkv: AbstractRWKV
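RWKV() resolves the tokenizer argument to a class by taking the file's basename without extension and looking it up in rwkv_map, so "abc_tokenizer" now yields the new MusicAbcRWKV and "tokenizer-midi" the renamed MusicMidiRWKV. A minimal sketch of that lookup; the stub classes and the missing-key fallback are placeholders, not the real implementations:

import os

class TextRWKV: ...
class MusicMidiRWKV: ...
class MusicAbcRWKV: ...

rwkv_map = {
    "20B_tokenizer": TextRWKV,
    "rwkv_vocab_v20230424": TextRWKV,
    "tokenizer-midi": MusicMidiRWKV,
    "abc_tokenizer": MusicAbcRWKV,
}

def resolve(tokenizer: str):
    name = os.path.splitext(os.path.basename(tokenizer))[0]
    return rwkv_map.get(name, TextRWKV)  # fallback is an assumption

print(resolve("/some/dir/tokenizer-midi.json").__name__)  # MusicMidiRWKV
print(resolve("abc_tokenizer").__name__)                  # MusicAbcRWKV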