RWKVType now no longer relies on the file name
This commit is contained in:
parent
1d7f19ffaf
commit
627a20936d
@ -10,22 +10,6 @@ import global_var
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_path(model_path: str):
|
|
||||||
model_path = model_path.lower()
|
|
||||||
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
|
|
||||||
|
|
||||||
default_tokens_path = tokenizer_dir + "20B_tokenizer.json"
|
|
||||||
|
|
||||||
if "raven" in model_path:
|
|
||||||
return default_tokens_path
|
|
||||||
elif "world" in model_path:
|
|
||||||
return "rwkv_vocab_v20230424"
|
|
||||||
elif "midi" in model_path:
|
|
||||||
return tokenizer_dir + "tokenizer-midi.json"
|
|
||||||
else:
|
|
||||||
return default_tokens_path
|
|
||||||
|
|
||||||
|
|
||||||
class SwitchModelBody(BaseModel):
|
class SwitchModelBody(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
strategy: str
|
strategy: str
|
||||||
@ -67,25 +51,10 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
|||||||
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
||||||
|
|
||||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||||
tokenizer = (
|
|
||||||
get_tokens_path(body.model)
|
|
||||||
if body.tokenizer is None or body.tokenizer == ""
|
|
||||||
else body.tokenizer
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
global_var.set(
|
global_var.set(
|
||||||
global_var.Model,
|
global_var.Model,
|
||||||
TextRWKV(
|
RWKV(model=body.model, strategy=body.strategy, tokenizer=body.tokenizer),
|
||||||
model=body.model,
|
|
||||||
strategy=body.strategy,
|
|
||||||
tokens_path=tokenizer,
|
|
||||||
)
|
|
||||||
if "midi" not in body.model.lower()
|
|
||||||
else MusicRWKV(
|
|
||||||
model=body.model,
|
|
||||||
strategy=body.strategy,
|
|
||||||
tokens_path=tokenizer,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
@ -4,7 +4,7 @@ import os
|
|||||||
import pathlib
|
import pathlib
|
||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
from typing import Dict, Iterable, List, Tuple, Union
|
from typing import Dict, Iterable, List, Tuple, Union, Type
|
||||||
from utils.log import quick_log
|
from utils.log import quick_log
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@ -21,33 +21,21 @@ os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.res
|
|||||||
|
|
||||||
|
|
||||||
class RWKVType(Enum):
|
class RWKVType(Enum):
|
||||||
|
NoneType = auto()
|
||||||
Raven = auto()
|
Raven = auto()
|
||||||
World = auto()
|
World = auto()
|
||||||
Music = auto()
|
Music = auto()
|
||||||
|
|
||||||
|
|
||||||
class AbstractRWKV(ABC):
|
class AbstractRWKV(ABC):
|
||||||
def __init__(self, model: str, strategy: str, tokens_path: str):
|
def __init__(self, model, pipeline):
|
||||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
self.name = "rwkv"
|
||||||
|
self.model = model
|
||||||
# dynamic import to make RWKV_CUDA_ON work
|
self.pipeline = pipeline
|
||||||
if rwkv_beta:
|
|
||||||
from rwkv_pip.beta.model import (
|
|
||||||
RWKV as Model,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from rwkv_pip.model import (
|
|
||||||
RWKV as Model,
|
|
||||||
)
|
|
||||||
from rwkv_pip.utils import PIPELINE
|
|
||||||
|
|
||||||
filename, _ = os.path.splitext(os.path.basename(model))
|
|
||||||
self.name = filename
|
|
||||||
self.model = Model(model, strategy)
|
|
||||||
self.pipeline = PIPELINE(self.model, tokens_path)
|
|
||||||
self.model_state = None
|
self.model_state = None
|
||||||
self.model_tokens = []
|
self.model_tokens = []
|
||||||
self.rwkv_type: RWKVType = None
|
self.rwkv_type: RWKVType = RWKVType.NoneType
|
||||||
|
self.tokenizer_len = len(model.w["emb.weight"])
|
||||||
|
|
||||||
self.max_tokens_per_generation = 500
|
self.max_tokens_per_generation = 500
|
||||||
self.temperature = 1
|
self.temperature = 1
|
||||||
@ -348,8 +336,8 @@ class AbstractRWKV(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class TextRWKV(AbstractRWKV):
|
class TextRWKV(AbstractRWKV):
|
||||||
def __init__(self, model: str, strategy: str, tokens_path: str) -> None:
|
def __init__(self, model, pipeline) -> None:
|
||||||
super().__init__(model, strategy, tokens_path)
|
super().__init__(model, pipeline)
|
||||||
|
|
||||||
self.CHUNK_LEN = 256
|
self.CHUNK_LEN = 256
|
||||||
|
|
||||||
@ -361,16 +349,16 @@ class TextRWKV(AbstractRWKV):
|
|||||||
self.penalty_alpha_frequency = 1
|
self.penalty_alpha_frequency = 1
|
||||||
|
|
||||||
self.interface = ":"
|
self.interface = ":"
|
||||||
if "world" in self.name.lower():
|
if self.tokenizer_len < 65536:
|
||||||
self.rwkv_type = RWKVType.World
|
|
||||||
self.user = "Question"
|
|
||||||
self.bot = "Answer"
|
|
||||||
self.END_OF_LINE = 11
|
|
||||||
else:
|
|
||||||
self.rwkv_type = RWKVType.Raven
|
self.rwkv_type = RWKVType.Raven
|
||||||
self.user = "Bob"
|
self.user = "Bob"
|
||||||
self.bot = "Alice"
|
self.bot = "Alice"
|
||||||
self.END_OF_LINE = 187
|
self.END_OF_LINE = 187
|
||||||
|
else:
|
||||||
|
self.rwkv_type = RWKVType.World
|
||||||
|
self.user = "Question"
|
||||||
|
self.bot = "Answer"
|
||||||
|
self.END_OF_LINE = 11
|
||||||
|
|
||||||
self.AVOID_REPEAT_TOKENS = []
|
self.AVOID_REPEAT_TOKENS = []
|
||||||
AVOID_REPEAT = ",:?!"
|
AVOID_REPEAT = ",:?!"
|
||||||
@ -469,8 +457,8 @@ The following is a coherent verbose detailed conversation between a girl named {
|
|||||||
|
|
||||||
|
|
||||||
class MusicRWKV(AbstractRWKV):
|
class MusicRWKV(AbstractRWKV):
|
||||||
def __init__(self, model: str, strategy: str, tokens_path: str):
|
def __init__(self, model, pipeline):
|
||||||
super().__init__(model, strategy, tokens_path)
|
super().__init__(model, pipeline)
|
||||||
|
|
||||||
self.max_tokens_per_generation = 500
|
self.max_tokens_per_generation = 500
|
||||||
self.temperature = 1
|
self.temperature = 1
|
||||||
@ -510,6 +498,52 @@ class MusicRWKV(AbstractRWKV):
|
|||||||
return " " + delta
|
return " " + delta
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(tokenizer_len: int):
|
||||||
|
tokenizer_dir = f"{pathlib.Path(__file__).parent.parent.resolve()}/rwkv_pip/"
|
||||||
|
if tokenizer_len < 50277:
|
||||||
|
return tokenizer_dir + "tokenizer-midi.json"
|
||||||
|
elif tokenizer_len < 65536:
|
||||||
|
return tokenizer_dir + "20B_tokenizer.json"
|
||||||
|
else:
|
||||||
|
return "rwkv_vocab_v20230424"
|
||||||
|
|
||||||
|
|
||||||
|
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
||||||
|
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||||
|
|
||||||
|
# dynamic import to make RWKV_CUDA_ON work
|
||||||
|
if rwkv_beta:
|
||||||
|
from rwkv_pip.beta.model import (
|
||||||
|
RWKV as Model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from rwkv_pip.model import (
|
||||||
|
RWKV as Model,
|
||||||
|
)
|
||||||
|
from rwkv_pip.utils import PIPELINE
|
||||||
|
|
||||||
|
filename, _ = os.path.splitext(os.path.basename(model))
|
||||||
|
model = Model(model, strategy)
|
||||||
|
if not tokenizer:
|
||||||
|
tokenizer = get_tokenizer(len(model.w["emb.weight"]))
|
||||||
|
pipeline = PIPELINE(model, tokenizer)
|
||||||
|
|
||||||
|
rwkv_map: dict[str, Type[AbstractRWKV]] = {
|
||||||
|
"20B_tokenizer": TextRWKV,
|
||||||
|
"rwkv_vocab_v20230424": TextRWKV,
|
||||||
|
"tokenizer-midi": MusicRWKV,
|
||||||
|
}
|
||||||
|
tokenizer_name = os.path.splitext(os.path.basename(tokenizer))[0]
|
||||||
|
rwkv: AbstractRWKV
|
||||||
|
if tokenizer_name in rwkv_map:
|
||||||
|
rwkv = rwkv_map[tokenizer_name](model, pipeline)
|
||||||
|
else:
|
||||||
|
rwkv = TextRWKV(model, pipeline)
|
||||||
|
rwkv.name = filename
|
||||||
|
|
||||||
|
return rwkv
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBody(BaseModel):
|
class ModelConfigBody(BaseModel):
|
||||||
max_tokens: int = Field(default=None, gt=0, le=102400)
|
max_tokens: int = Field(default=None, gt=0, le=102400)
|
||||||
temperature: float = Field(default=None, ge=0, le=2)
|
temperature: float = Field(default=None, ge=0, le=2)
|
||||||
|
Loading…
Reference in New Issue
Block a user