This commit is contained in:
josc146 2024-05-28 21:27:10 +08:00
parent 6b4381ee77
commit 3488d22d22
3 changed files with 28 additions and 24 deletions

View File

@ -96,7 +96,9 @@ def copy_tensor_to_cpu(tensors):
elif tensors_type == np.ndarray: # rwkv.cpp elif tensors_type == np.ndarray: # rwkv.cpp
copied = tensors copied = tensors
else: # WebGPU state else: # WebGPU state
copied = tensors.back() model = global_var.get(global_var.Model)
if model:
copied = model.model.model.back_state()
return copied, devices return copied, devices
@ -238,11 +240,14 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
state: Union[Any, None] = v["state"] state: Union[Any, None] = v["state"]
logits: Union[Any, None] = v["logits"] logits: Union[Any, None] = v["logits"]
if type(state) == list and hasattr(state[0], "device"): # torch state_type = type(state)
if state_type == list and hasattr(state[0], "device"): # torch
state = [ state = [
(
tensor.to(devices[i]) tensor.to(devices[i])
if devices[i] != torch.device("cpu") if devices[i] != torch.device("cpu")
else tensor.clone() else tensor.clone()
)
for i, tensor in enumerate(state) for i, tensor in enumerate(state)
] ]
logits = ( logits = (
@ -250,7 +255,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
if logits_device != torch.device("cpu") if logits_device != torch.device("cpu")
else logits.clone() else logits.clone()
) )
else: # rwkv.cpp, WebGPU elif state_type == np.ndarray: # rwkv.cpp
logits = np.copy(logits)
else: # WebGPU
logits = np.copy(logits) logits = np.copy(logits)
quick_log(request, body, "Hit:\n" + prompt) quick_log(request, body, "Hit:\n" + prompt)

View File

@ -13,13 +13,6 @@ except ModuleNotFoundError:
class RWKV: class RWKV:
def __init__(self, model_path: str, strategy: str = None): def __init__(self, model_path: str, strategy: str = None):
self.info = wrp.peek_info(model_path)
self.w = {} # fake weight
self.w["emb.weight"] = [0] * self.info.num_vocab
self.version = str(self.info.version).lower()
self.wrp = getattr(wrp, self.version)
self.version = float(self.version.replace("v", ""))
layer = ( layer = (
int(s.lstrip("layer")) int(s.lstrip("layer"))
for s in strategy.split() for s in strategy.split()
@ -33,21 +26,25 @@ class RWKV:
for s in s.split(",") for s in s.split(",")
if s.startswith("chunk") if s.startswith("chunk")
) )
self.token_chunk_size = next(chunk_size, 32)
args = { args = {
"file": model_path, "path": model_path,
"turbo": True,
"quant": next(layer, 31) if "i8" in strategy else 0, "quant": next(layer, 31) if "i8" in strategy else 0,
"quant_nf4": next(layer, 26) if "i4" in strategy else 0, "quant_nf4": next(layer, 26) if "i4" in strategy else 0,
"token_chunk_size": next(chunk_size, 32),
"lora": None,
} }
self.model = self.wrp.Model(**args) self.model = wrp.Model(**args)
self.info = self.model.info()
self.w = {} # fake weight
self.w["emb.weight"] = [0] * self.info.num_vocab
self.version = str(self.info.version).lower()
self.version = float(self.version.lower().replace("v", ""))
def forward(self, tokens: List[int], state: Union[Any, None] = None): def forward(self, tokens: List[int], state: Union[Any, None] = None):
if type(state).__name__ == "BackedState": # memory state if state is None:
gpu_state = self.wrp.ModelState(self.model, 1) self.model.clear_state()
gpu_state.load(state) elif type(state).__name__ == "State_Cpu":
else: self.model.load_state(state)
gpu_state = state logits = self.model.run(tokens, self.token_chunk_size)
return self.wrp.run_one(self.model, tokens, gpu_state) ret_state = "State_Gpu"
return logits, ret_state