RWKVType now no longer relies on the file name
This commit is contained in:
		
							parent
							
								
									1d7f19ffaf
								
							
						
					
					
						commit
						627a20936d
					
				@ -10,22 +10,6 @@ import global_var
 | 
				
			|||||||
router = APIRouter()
 | 
					router = APIRouter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_tokens_path(model_path: str):
 | 
					 | 
				
			||||||
    model_path = model_path.lower()
 | 
					 | 
				
			||||||
    tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if "raven" in model_path:
 | 
					 | 
				
			||||||
        return default_tokens_path
 | 
					 | 
				
			||||||
    elif "world" in model_path:
 | 
					 | 
				
			||||||
        return "rwkv_vocab_v20230424"
 | 
					 | 
				
			||||||
    elif "midi" in model_path:
 | 
					 | 
				
			||||||
        return tokenizer_dir + "tokenizer-midi.json"
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        return default_tokens_path
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SwitchModelBody(BaseModel):
 | 
					class SwitchModelBody(BaseModel):
 | 
				
			||||||
    model: str
 | 
					    model: str
 | 
				
			||||||
    strategy: str
 | 
					    strategy: str
 | 
				
			||||||
@ -67,25 +51,10 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
 | 
				
			|||||||
    os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
 | 
					    os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
 | 
					    global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
 | 
				
			||||||
    tokenizer = (
 | 
					 | 
				
			||||||
        get_tokens_path(body.model)
 | 
					 | 
				
			||||||
        if body.tokenizer is None or body.tokenizer == ""
 | 
					 | 
				
			||||||
        else body.tokenizer
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        global_var.set(
 | 
					        global_var.set(
 | 
				
			||||||
            global_var.Model,
 | 
					            global_var.Model,
 | 
				
			||||||
            TextRWKV(
 | 
					            RWKV(model=body.model, strategy=body.strategy, tokenizer=body.tokenizer),
 | 
				
			||||||
                model=body.model,
 | 
					 | 
				
			||||||
                strategy=body.strategy,
 | 
					 | 
				
			||||||
                tokens_path=tokenizer,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            if "midi" not in body.model.lower()
 | 
					 | 
				
			||||||
            else MusicRWKV(
 | 
					 | 
				
			||||||
                model=body.model,
 | 
					 | 
				
			||||||
                strategy=body.strategy,
 | 
					 | 
				
			||||||
                tokens_path=tokenizer,
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    except Exception as e:
 | 
					    except Exception as e:
 | 
				
			||||||
        print(e)
 | 
					        print(e)
 | 
				
			||||||
 | 
				
			|||||||
@ -4,7 +4,7 @@ import os
 | 
				
			|||||||
import pathlib
 | 
					import pathlib
 | 
				
			||||||
import copy
 | 
					import copy
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
from typing import Dict, Iterable, List, Tuple, Union
 | 
					from typing import Dict, Iterable, List, Tuple, Union, Type
 | 
				
			||||||
from utils.log import quick_log
 | 
					from utils.log import quick_log
 | 
				
			||||||
from fastapi import HTTPException
 | 
					from fastapi import HTTPException
 | 
				
			||||||
from pydantic import BaseModel, Field
 | 
					from pydantic import BaseModel, Field
 | 
				
			||||||
@ -21,33 +21,21 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RWKVType(Enum):
 | 
					class RWKVType(Enum):
 | 
				
			||||||
 | 
					    NoneType = auto()
 | 
				
			||||||
    Raven = auto()
 | 
					    Raven = auto()
 | 
				
			||||||
    World = auto()
 | 
					    World = auto()
 | 
				
			||||||
    Music = auto()
 | 
					    Music = auto()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AbstractRWKV(ABC):
 | 
					class AbstractRWKV(ABC):
 | 
				
			||||||
    def __init__(self, model: str, strategy: str, tokens_path: str):
 | 
					    def __init__(self, model, pipeline):
 | 
				
			||||||
        rwkv_beta = global_var.get(global_var.Args).rwkv_beta
 | 
					        self.name = "rwkv"
 | 
				
			||||||
 | 
					        self.model = model
 | 
				
			||||||
        # dynamic import to make RWKV_CUDA_ON work
 | 
					        self.pipeline = pipeline
 | 
				
			||||||
        if rwkv_beta:
 | 
					 | 
				
			||||||
            from rwkv_pip.beta.model import (
 | 
					 | 
				
			||||||
                RWKV as Model,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            from rwkv_pip.model import (
 | 
					 | 
				
			||||||
                RWKV as Model,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        from rwkv_pip.utils import PIPELINE
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        filename, _ = os.path.splitext(os.path.basename(model))
 | 
					 | 
				
			||||||
        self.name = filename
 | 
					 | 
				
			||||||
        self.model = Model(model, strategy)
 | 
					 | 
				
			||||||
        self.pipeline = PIPELINE(self.model, tokens_path)
 | 
					 | 
				
			||||||
        self.model_state = None
 | 
					        self.model_state = None
 | 
				
			||||||
        self.model_tokens = []
 | 
					        self.model_tokens = []
 | 
				
			||||||
        self.rwkv_type: RWKVType = None
 | 
					        self.rwkv_type: RWKVType = RWKVType.NoneType
 | 
				
			||||||
 | 
					        self.tokenizer_len = len(model.w["emb.weight"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.max_tokens_per_generation = 500
 | 
					        self.max_tokens_per_generation = 500
 | 
				
			||||||
        self.temperature = 1
 | 
					        self.temperature = 1
 | 
				
			||||||
@ -348,8 +336,8 @@ class AbstractRWKV(ABC):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TextRWKV(AbstractRWKV):
 | 
					class TextRWKV(AbstractRWKV):
 | 
				
			||||||
    def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
 | 
					    def __init__(self, model, pipeline) -> None:
 | 
				
			||||||
        super().__init__(model, strategy, tokens_path)
 | 
					        super().__init__(model, pipeline)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.CHUNK_LEN = 256
 | 
					        self.CHUNK_LEN = 256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -361,16 +349,16 @@ class TextRWKV(AbstractRWKV):
 | 
				
			|||||||
        self.penalty_alpha_frequency = 1
 | 
					        self.penalty_alpha_frequency = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.interface = ":"
 | 
					        self.interface = ":"
 | 
				
			||||||
        if "world" in self.name.lower():
 | 
					        if self.tokenizer_len < 65536:
 | 
				
			||||||
            self.rwkv_type = RWKVType.World
 | 
					 | 
				
			||||||
            self.user = "Question"
 | 
					 | 
				
			||||||
            self.bot = "Answer"
 | 
					 | 
				
			||||||
            self.END_OF_LINE = 11
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.rwkv_type = RWKVType.Raven
 | 
					            self.rwkv_type = RWKVType.Raven
 | 
				
			||||||
            self.user = "Bob"
 | 
					            self.user = "Bob"
 | 
				
			||||||
            self.bot = "Alice"
 | 
					            self.bot = "Alice"
 | 
				
			||||||
            self.END_OF_LINE = 187
 | 
					            self.END_OF_LINE = 187
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.rwkv_type = RWKVType.World
 | 
				
			||||||
 | 
					            self.user = "Question"
 | 
				
			||||||
 | 
					            self.bot = "Answer"
 | 
				
			||||||
 | 
					            self.END_OF_LINE = 11
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.AVOID_REPEAT_TOKENS = []
 | 
					        self.AVOID_REPEAT_TOKENS = []
 | 
				
			||||||
        AVOID_REPEAT = ",:?!"
 | 
					        AVOID_REPEAT = ",:?!"
 | 
				
			||||||
@ -469,8 +457,8 @@ The following is a coherent verbose detailed conversation between a girl named {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MusicRWKV(AbstractRWKV):
 | 
					class MusicRWKV(AbstractRWKV):
 | 
				
			||||||
    def __init__(self, model: str, strategy: str, tokens_path: str):
 | 
					    def __init__(self, model, pipeline):
 | 
				
			||||||
        super().__init__(model, strategy, tokens_path)
 | 
					        super().__init__(model, pipeline)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.max_tokens_per_generation = 500
 | 
					        self.max_tokens_per_generation = 500
 | 
				
			||||||
        self.temperature = 1
 | 
					        self.temperature = 1
 | 
				
			||||||
@ -510,6 +498,52 @@ class MusicRWKV(AbstractRWKV):
 | 
				
			|||||||
        return " " + delta
 | 
					        return " " + delta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_tokenizer(tokenizer_len: int):
 | 
				
			||||||
 | 
					    tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
 | 
				
			||||||
 | 
					    if tokenizer_len < 50277:
 | 
				
			||||||
 | 
					        return tokenizer_dir + "tokenizer-midi.json"
 | 
				
			||||||
 | 
					    elif tokenizer_len < 65536:
 | 
				
			||||||
 | 
					        return tokenizer_dir + "20B_tokenizer.json"
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return "rwkv_vocab_v20230424"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
 | 
				
			||||||
 | 
					    rwkv_beta = global_var.get(global_var.Args).rwkv_beta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # dynamic import to make RWKV_CUDA_ON work
 | 
				
			||||||
 | 
					    if rwkv_beta:
 | 
				
			||||||
 | 
					        from rwkv_pip.beta.model import (
 | 
				
			||||||
 | 
					            RWKV as Model,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        from rwkv_pip.model import (
 | 
				
			||||||
 | 
					            RWKV as Model,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    from rwkv_pip.utils import PIPELINE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    filename, _ = os.path.splitext(os.path.basename(model))
 | 
				
			||||||
 | 
					    model = Model(model, strategy)
 | 
				
			||||||
 | 
					    if not tokenizer:
 | 
				
			||||||
 | 
					        tokenizer = get_tokenizer(len(model.w["emb.weight"]))
 | 
				
			||||||
 | 
					    pipeline = PIPELINE(model, tokenizer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rwkv_map: dict[str, Type[AbstractRWKV]] = {
 | 
				
			||||||
 | 
					        "20B_tokenizer": TextRWKV,
 | 
				
			||||||
 | 
					        "rwkv_vocab_v20230424": TextRWKV,
 | 
				
			||||||
 | 
					        "tokenizer-midi": MusicRWKV,
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
 | 
				
			||||||
 | 
					    rwkv: AbstractRWKV
 | 
				
			||||||
 | 
					    if tokenizer_name in rwkv_map:
 | 
				
			||||||
 | 
					        rwkv = rwkv_map[tokenizer_name](model, pipeline)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        rwkv = TextRWKV(model, pipeline)
 | 
				
			||||||
 | 
					    rwkv.name = filename
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return rwkv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ModelConfigBody(BaseModel):
 | 
					class ModelConfigBody(BaseModel):
 | 
				
			||||||
    max_tokens: int = Field(default=None, gt=0, le=102400)
 | 
					    max_tokens: int = Field(default=None, gt=0, le=102400)
 | 
				
			||||||
    temperature: float = Field(default=None, ge=0, le=2)
 | 
					    temperature: float = Field(default=None, ge=0, le=2)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user