RWKVType now no longer relies on the file name

This commit is contained in:
josc146 2023-10-26 16:55:33 +08:00
parent 1d7f19ffaf
commit 627a20936d
2 changed files with 65 additions and 62 deletions

View File

@ -10,22 +10,6 @@ import global_var
router = APIRouter() router = APIRouter()
def get_tokens_path(model_path: str):
    """Pick a tokenizer location from keywords found in the model path.

    The match is case-insensitive and the keywords are tested in a fixed
    priority order ("raven", then "world", then "midi"); the first keyword
    present in ``model_path`` decides the result.

    Args:
        model_path: Path or file name of the model to load.

    Returns:
        A JSON tokenizer path under the ``rwkv_pip`` directory for raven and
        midi models (and for anything unrecognised), or the bare name
        ``"rwkv_vocab_v20230424"`` for world models.
    """
    lowered = model_path.lower()
    rwkv_pip_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
    fallback = rwkv_pip_dir + "20B_tokenizer.json"

    # Order matters: a path containing several keywords resolves to the
    # first entry listed here.
    keyword_targets = (
        ("raven", fallback),
        ("world", "rwkv_vocab_v20230424"),
        ("midi", rwkv_pip_dir + "tokenizer-midi.json"),
    )
    for keyword, target in keyword_targets:
        if keyword in lowered:
            return target
    return fallback
class SwitchModelBody(BaseModel): class SwitchModelBody(BaseModel):
model: str model: str
strategy: str strategy: str
@ -67,25 +51,10 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0" os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading) global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
tokenizer = (
get_tokens_path(body.model)
if body.tokenizer is None or body.tokenizer == ""
else body.tokenizer
)
try: try:
global_var.set( global_var.set(
global_var.Model, global_var.Model,
TextRWKV( RWKV(model=body.model, strategy=body.strategy, tokenizer=body.tokenizer),
model=body.model,
strategy=body.strategy,
tokens_path=tokenizer,
)
if "midi" not in body.model.lower()
else MusicRWKV(
model=body.model,
strategy=body.strategy,
tokens_path=tokenizer,
),
) )
except Exception as e: except Exception as e:
print(e) print(e)

View File

@ -4,7 +4,7 @@ import os
import pathlib import pathlib
import copy import copy
import re import re
from typing import Dict, Iterable, List, Tuple, Union from typing import Dict, Iterable, List, Tuple, Union, Type
from utils.log import quick_log from utils.log import quick_log
from fastapi import HTTPException from fastapi import HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -21,33 +21,21 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res
class RWKVType(Enum): class RWKVType(Enum):
NoneType = auto()
Raven = auto() Raven = auto()
World = auto() World = auto()
Music = auto() Music = auto()
class AbstractRWKV(ABC): class AbstractRWKV(ABC):
def __init__(self, model: str, strategy: str, tokens_path: str): def __init__(self, model, pipeline):
rwkv_beta = global_var.get(global_var.Args).rwkv_beta self.name = "rwkv"
self.model = model
# dynamic import to make RWKV_CUDA_ON work self.pipeline = pipeline
if rwkv_beta:
from rwkv_pip.beta.model import (
RWKV as Model,
)
else:
from rwkv_pip.model import (
RWKV as Model,
)
from rwkv_pip.utils import PIPELINE
filename, _ = os.path.splitext(os.path.basename(model))
self.name = filename
self.model = Model(model, strategy)
self.pipeline = PIPELINE(self.model, tokens_path)
self.model_state = None self.model_state = None
self.model_tokens = [] self.model_tokens = []
self.rwkv_type: RWKVType = None self.rwkv_type: RWKVType = RWKVType.NoneType
self.tokenizer_len = len(model.w["emb.weight"])
self.max_tokens_per_generation = 500 self.max_tokens_per_generation = 500
self.temperature = 1 self.temperature = 1
@ -348,8 +336,8 @@ class AbstractRWKV(ABC):
class TextRWKV(AbstractRWKV): class TextRWKV(AbstractRWKV):
def __init__(self, model: str, strategy: str, tokens_path: str) -> None: def __init__(self, model, pipeline) -> None:
super().__init__(model, strategy, tokens_path) super().__init__(model, pipeline)
self.CHUNK_LEN = 256 self.CHUNK_LEN = 256
@ -361,16 +349,16 @@ class TextRWKV(AbstractRWKV):
self.penalty_alpha_frequency = 1 self.penalty_alpha_frequency = 1
self.interface = ":" self.interface = ":"
if "world" in self.name.lower(): if self.tokenizer_len < 65536:
self.rwkv_type = RWKVType.World
self.user = "Question"
self.bot = "Answer"
self.END_OF_LINE = 11
else:
self.rwkv_type = RWKVType.Raven self.rwkv_type = RWKVType.Raven
self.user = "Bob" self.user = "Bob"
self.bot = "Alice" self.bot = "Alice"
self.END_OF_LINE = 187 self.END_OF_LINE = 187
else:
self.rwkv_type = RWKVType.World
self.user = "Question"
self.bot = "Answer"
self.END_OF_LINE = 11
self.AVOID_REPEAT_TOKENS = [] self.AVOID_REPEAT_TOKENS = []
AVOID_REPEAT = "" AVOID_REPEAT = ""
@ -469,8 +457,8 @@ The following is a coherent verbose detailed conversation between a girl named {
class MusicRWKV(AbstractRWKV): class MusicRWKV(AbstractRWKV):
def __init__(self, model: str, strategy: str, tokens_path: str): def __init__(self, model, pipeline):
super().__init__(model, strategy, tokens_path) super().__init__(model, pipeline)
self.max_tokens_per_generation = 500 self.max_tokens_per_generation = 500
self.temperature = 1 self.temperature = 1
@ -510,6 +498,52 @@ class MusicRWKV(AbstractRWKV):
return " " + delta return " " + delta
def get_tokenizer(tokenizer_len: int):
    """Choose a tokenizer from the size of the model's embedding table.

    Args:
        tokenizer_len: Number of rows in the model's ``emb.weight`` (the
            vocabulary size the model was built with).

    Returns:
        The bare name ``"rwkv_vocab_v20230424"`` when ``tokenizer_len`` is
        65536 or more; otherwise a JSON path under the ``rwkv_pip``
        directory: ``tokenizer-midi.json`` below 50277, and
        ``20B_tokenizer.json`` from 50277 up to 65535.
    """
    midi_upper_bound = 50277
    text_20b_upper_bound = 65536

    rwkv_pip_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"

    if tokenizer_len >= text_20b_upper_bound:
        return "rwkv_vocab_v20230424"
    if tokenizer_len >= midi_upper_bound:
        return rwkv_pip_dir + "20B_tokenizer.json"
    return rwkv_pip_dir + "tokenizer-midi.json"
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
    """Load a model and wrap it in the matching ``AbstractRWKV`` subclass.

    Args:
        model: Path to the model file. Its base name without extension
            becomes the ``name`` of the returned wrapper.
        strategy: Strategy string passed through to the model constructor.
        tokenizer: Tokenizer path or name. When ``None`` or empty, one is
            chosen by ``get_tokenizer`` from the length of the loaded
            model's ``emb.weight``.

    Returns:
        A ``MusicRWKV`` when the tokenizer's base name is
        ``"tokenizer-midi"``; a ``TextRWKV`` for every other tokenizer.
    """
    # dynamic import to make RWKV_CUDA_ON work
    if global_var.get(global_var.Args).rwkv_beta:
        from rwkv_pip.beta.model import RWKV as Model
    else:
        from rwkv_pip.model import RWKV as Model
    from rwkv_pip.utils import PIPELINE

    display_name = os.path.splitext(os.path.basename(model))[0]

    loaded_model = Model(model, strategy)
    if not tokenizer:
        tokenizer = get_tokenizer(len(loaded_model.w["emb.weight"]))
    pipeline = PIPELINE(loaded_model, tokenizer)

    # Wrapper class keyed by tokenizer base name; unknown tokenizers are
    # treated as text models.
    wrapper_by_tokenizer: Dict[str, Type[AbstractRWKV]] = {
        "20B_tokenizer": TextRWKV,
        "rwkv_vocab_v20230424": TextRWKV,
        "tokenizer-midi": MusicRWKV,
    }
    tokenizer_key, _ = os.path.splitext(os.path.basename(tokenizer))
    wrapper_cls = wrapper_by_tokenizer.get(tokenizer_key, TextRWKV)

    rwkv: AbstractRWKV = wrapper_cls(loaded_model, pipeline)
    rwkv.name = display_name
    return rwkv
class ModelConfigBody(BaseModel): class ModelConfigBody(BaseModel):
max_tokens: int = Field(default=None, gt=0, le=102400) max_tokens: int = Field(default=None, gt=0, le=102400)
temperature: float = Field(default=None, ge=0, le=2) temperature: float = Field(default=None, ge=0, le=2)