add rwkv-cuda-beta support (faster)
This commit is contained in:
@@ -10,6 +10,7 @@ from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
from routes import state_cache
|
||||
import global_var
|
||||
|
||||
|
||||
END_OF_TEXT = 0
|
||||
@@ -27,7 +28,17 @@ class RWKVType(Enum):
|
||||
|
||||
class AbstractRWKV(ABC):
|
||||
def __init__(self, model: str, strategy: str, tokens_path: str):
|
||||
from rwkv.model import RWKV as Model # dynamic import to make RWKV_CUDA_ON work
|
||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||
|
||||
# dynamic import to make RWKV_CUDA_ON work
|
||||
if rwkv_beta:
|
||||
from rwkv_pip.beta.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
else:
|
||||
from rwkv.model import (
|
||||
RWKV as Model,
|
||||
)
|
||||
from rwkv_pip.utils import PIPELINE
|
||||
|
||||
filename, _ = os.path.splitext(os.path.basename(model))
|
||||
@@ -221,7 +232,7 @@ class AbstractRWKV(ABC):
|
||||
return state[0].tolist(), token_len
|
||||
|
||||
def generate(
|
||||
self, prompt: str, stop: Union[str, List[str]] = None
|
||||
self, prompt: str, stop: Union[str, List[str], None] = None
|
||||
) -> Iterable[Tuple[str, str, int, int]]:
|
||||
quick_log(None, None, "Generation Prompt:\n" + prompt)
|
||||
cache = None
|
||||
@@ -438,8 +449,10 @@ The following is a coherent verbose detailed conversation between a girl named {
|
||||
{bot} usually gives {user} kind, helpful and informative advices.\n
|
||||
"""
|
||||
if self.rwkv_type == RWKVType.Raven
|
||||
else f"{user}{interface} hi\n\n{bot}{interface} Hi. "
|
||||
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
|
||||
else (
|
||||
f"{user}{interface} hi\n\n{bot}{interface} Hi. "
|
||||
+ "I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.\n\n"
|
||||
)
|
||||
)
|
||||
logits, _ = self.run_rnn(self.fix_tokens(self.pipeline.encode(preset_system)))
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user