add support for dynamic state-tuned models

Author: josc146
Date: 2024-05-12 21:51:24 +08:00
Parent: b52873cb37
Commit: a2bbbabee2

12 changed files with 230 additions and 15 deletions

View File

@@ -120,6 +120,9 @@ def update_config(body: ModelConfigBody):
        model_config = ModelConfigBody()
        global_var.set(global_var.Model_Config, model_config)
    merge_model(model_config, body)
+    exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state)
+    if exception is not None:
+        raise exception
    print("Updated Model Config:", model_config)
    return "success"

View File

@@ -176,6 +176,19 @@ def reset_state():
return "success"
def force_reset_state():
global trie, dtrie
if trie is None:
return
import cyac
trie = cyac.Trie()
dtrie = {}
gc.collect()
class LongestPrefixStateBody(BaseModel):
prompt: str
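
force_reset_state exists because cached entries map prompt prefixes to model states computed from the previous initial state; the moment a different state-tuned file is loaded, every entry is stale. A toy illustration of that invariant (a hypothetical stand-in, not the project's real trie cache):

from typing import Dict, List

cache: Dict[str, List[float]] = {}

def run(initial_state: List[float], prompt: str) -> List[float]:
    # stand-in for inference: the resulting state depends on the initial state
    return [s + float(len(prompt)) for s in initial_state]

initial = [0.0]                                  # zero-initialized state
cache["Hello"] = run(initial, "Hello")

initial = [42.0]                                 # a state-tuned file was loaded
assert cache["Hello"] != run(initial, "Hello")   # cached entry is now stale
cache.clear()                                    # hence force_reset_state()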

View File

@@ -7,7 +7,7 @@ import re
import time
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
from utils.log import quick_log
-from fastapi import HTTPException
+from fastapi import HTTPException, status
from pydantic import BaseModel, Field
from routes import state_cache
import global_var
@@ -27,6 +27,7 @@ class AbstractRWKV(ABC):
        self.EOS_ID = 0
        self.name = "rwkv"
+        self.model_path = ""
        self.version = 4
        self.model = model
        self.pipeline = pipeline
@@ -43,6 +44,8 @@ class AbstractRWKV(ABC):
        self.penalty_alpha_frequency = 1
        self.penalty_decay = 0.996
        self.global_penalty = False
+        self.state_path = ""
+        self.state_tuned = None

    @abstractmethod
    def adjust_occurrence(self, occurrence: Dict, token: int):
@@ -236,7 +239,10 @@ class AbstractRWKV(ABC):
        except HTTPException:
            pass
        if cache is None or cache["prompt"] == "" or cache["state"] is None:
-            self.model_state = None
+            if self.state_path:
+                # start generation from a per-request copy of the tuned state
+                self.model_state = copy.deepcopy(self.state_tuned)
+            else:
+                self.model_state = None
            self.model_tokens = []
        else:
            delta_prompt = prompt[len(cache["prompt"]) :]
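
The deepcopy matters: generation advances model_state in place, so handing out the state_tuned tensors directly would corrupt the template after the first request. A minimal sketch of the pattern, assuming in-place mutation during generation:

import copy

template = [[0.0, 0.0]]               # stands in for the tuned state tensors
working = copy.deepcopy(template)     # fresh per-request working copy
working[0][0] = 1.0                   # in-place update during generation
assert template == [[0.0, 0.0]]       # the tuned template stays intact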
@@ -606,13 +612,13 @@ def get_model_path(model_path: str) -> str:
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
-    model = get_model_path(model)
+    model_path = get_model_path(model)

    rwkv_beta = global_var.get(global_var.Args).rwkv_beta
    rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
    webgpu = global_var.get(global_var.Args).webgpu

-    if "midi" in model.lower() or "abc" in model.lower():
+    if "midi" in model_path.lower() or "abc" in model_path.lower():
        os.environ["RWKV_RESCALE_LAYER"] = "999"

    # dynamic import to make RWKV_CUDA_ON work
@@ -637,8 +643,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
        )
    from rwkv_pip.utils import PIPELINE

-    filename, _ = os.path.splitext(os.path.basename(model))
-    model = Model(model, strategy)
+    filename, _ = os.path.splitext(os.path.basename(model_path))
+    model = Model(model_path, strategy)
    if not tokenizer:
        tokenizer = get_tokenizer(len(model.w["emb.weight"]))
    pipeline = PIPELINE(model, tokenizer)
@@ -671,6 +677,7 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
    else:
        rwkv = TextRWKV(model, pipeline)
    rwkv.name = filename
+    rwkv.model_path = model_path
    rwkv.version = model.version

    return rwkv
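
A hedged usage sketch of the factory after the rename; the model path and strategy string below are illustrative assumptions:

rwkv = RWKV(
    model="models/RWKV-x060-World-3B.pth",  # hypothetical file
    strategy="cuda fp16",
    tokenizer=None,  # inferred from the vocab size of emb.weight when omitted
)
print(rwkv.name, rwkv.version, rwkv.model_path)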
@@ -688,6 +695,7 @@ class ModelConfigBody(BaseModel):
        default=None,
        description="Whether to include the submitted prompt as a penalty factor when generating a response. Turning this off gives the same results as the official RWKV Gradio demo; turning it on helps avoid duplicated output.",
    )
+    state: str = Field(default=None, description="state-tuned file path")

    model_config = {
        "json_schema_extra": {
@@ -699,11 +707,80 @@ class ModelConfigBody(BaseModel):
"frequency_penalty": 1,
"penalty_decay": 0.996,
"global_penalty": False,
"state": "",
}
}
}
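
The new state field is a plain file path, and the empty string is meaningful: it unloads any active state (see load_rwkv_state below). A hypothetical request body, built with the pydantic model itself:

cfg = ModelConfigBody(state="models/my-rwkv-state.pth")  # hypothetical path
print(cfg.model_dump_json(exclude_none=True))            # {"state": "models/my-rwkv-state.pth"}
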
+def load_rwkv_state(model: AbstractRWKV, state_path: str) -> Union[HTTPException, None]:
+    if model:
+        if state_path:
+            if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
+                import torch
+
+                state_path = get_model_path(state_path)
+                if model.state_path == state_path:
+                    return  # this state is already loaded
+
+                state_raw = torch.load(state_path, map_location="cpu")
+                state_raw_shape = next(iter(state_raw.values())).shape
+
+                args = model.model.args
+                # a state checkpoint holds one att.time_state tensor per layer,
+                # shaped (n_head, head_size, head_size) with n_head * head_size == n_embd
+                if (
+                    len(state_raw) != args.n_layer
+                    or state_raw_shape[0] * state_raw_shape[1] != args.n_embd
+                ):
+                    if model.state_path:
+                        pass  # keep the previously loaded state
+                    else:
+                        print("state failed to load")
+                    return HTTPException(
+                        status.HTTP_400_BAD_REQUEST, "state shape mismatch"
+                    )
+
+                strategy = model.model.strategy
+                # three state slots per layer: att token-shift, att wkv state, ffn token-shift
+                model.state_tuned = [None] * args.n_layer * 3
+
+                for i in range(args.n_layer):
+                    dd = strategy[i]
+                    dev = dd.device
+                    atype = dd.atype
+                    model.state_tuned[i * 3 + 0] = torch.zeros(
+                        args.n_embd, dtype=atype, requires_grad=False, device=dev
+                    ).contiguous()
+                    model.state_tuned[i * 3 + 1] = (
+                        state_raw[f"blocks.{i}.att.time_state"]
+                        .transpose(1, 2)
+                        .to(dtype=torch.float, device=dev)
+                        .requires_grad_(False)
+                        .contiguous()
+                    )
+                    model.state_tuned[i * 3 + 2] = torch.zeros(
+                        args.n_embd, dtype=atype, requires_grad=False, device=dev
+                    ).contiguous()
+
+                # cached prompt states were computed from the old initial state
+                state_cache.force_reset_state()
+                model.state_path = state_path
+                print("state loaded")
+            else:
+                if model.state_path:
+                    pass  # keep the previously loaded state
+                else:
+                    print("state failed to load")
+                return HTTPException(
+                    status.HTTP_400_BAD_REQUEST,
+                    "file format of the model or state model not supported",
+                )
+        else:
+            state_cache.force_reset_state()
+            model.state_path = ""
+            model.state_tuned = None  # TODO cached
+            print("state unloaded")
+    else:
+        print("state not loaded")
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
    if body.max_tokens is not None:
        model.max_tokens_per_generation = body.max_tokens
@@ -724,6 +801,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
        model.top_k = body.top_k
    if body.global_penalty is not None:
        model.global_penalty = body.global_penalty
+    if body.state is not None:
+        load_rwkv_state(model, body.state)
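
Unlike update_config, which raises the returned HTTPException, set_rwkv_config discards load_rwkv_state's return value, so a failed load here is only logged. Loading and unloading through the config body, as a sketch (model stands for a loaded AbstractRWKV):

set_rwkv_config(model, ModelConfigBody(state="models/my-rwkv-state.pth"))  # load (hypothetical path)
set_rwkv_config(model, ModelConfigBody(state=""))  # unload and reset the prompt cache
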
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
@@ -736,4 +815,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
        penalty_decay=model.penalty_decay,
        top_k=model.top_k,
        global_penalty=model.global_penalty,
+        state=model.state_path,
    )
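
Finally, the active state path is surfaced back through get_rwkv_config, so a client can verify which state is in effect (again, model stands for a loaded AbstractRWKV):

cfg = get_rwkv_config(model)
print(cfg.state)  # "" when no state-tuned file is loaded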