bump webgpu(python) (https://github.com/cryscan/web-rwkv-py)
This commit is contained in:
parent
6b4381ee77
commit
3488d22d22
@ -96,7 +96,9 @@ def copy_tensor_to_cpu(tensors):
|
|||||||
elif tensors_type == np.ndarray: # rwkv.cpp
|
elif tensors_type == np.ndarray: # rwkv.cpp
|
||||||
copied = tensors
|
copied = tensors
|
||||||
else: # WebGPU state
|
else: # WebGPU state
|
||||||
copied = tensors.back()
|
model = global_var.get(global_var.Model)
|
||||||
|
if model:
|
||||||
|
copied = model.model.model.back_state()
|
||||||
|
|
||||||
return copied, devices
|
return copied, devices
|
||||||
|
|
||||||
@ -238,11 +240,14 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
state: Union[Any, None] = v["state"]
|
state: Union[Any, None] = v["state"]
|
||||||
logits: Union[Any, None] = v["logits"]
|
logits: Union[Any, None] = v["logits"]
|
||||||
|
|
||||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
state_type = type(state)
|
||||||
|
if state_type == list and hasattr(state[0], "device"): # torch
|
||||||
state = [
|
state = [
|
||||||
|
(
|
||||||
tensor.to(devices[i])
|
tensor.to(devices[i])
|
||||||
if devices[i] != torch.device("cpu")
|
if devices[i] != torch.device("cpu")
|
||||||
else tensor.clone()
|
else tensor.clone()
|
||||||
|
)
|
||||||
for i, tensor in enumerate(state)
|
for i, tensor in enumerate(state)
|
||||||
]
|
]
|
||||||
logits = (
|
logits = (
|
||||||
@ -250,7 +255,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
|||||||
if logits_device != torch.device("cpu")
|
if logits_device != torch.device("cpu")
|
||||||
else logits.clone()
|
else logits.clone()
|
||||||
)
|
)
|
||||||
else: # rwkv.cpp, WebGPU
|
elif state_type == np.ndarray: # rwkv.cpp
|
||||||
|
logits = np.copy(logits)
|
||||||
|
else: # WebGPU
|
||||||
logits = np.copy(logits)
|
logits = np.copy(logits)
|
||||||
|
|
||||||
quick_log(request, body, "Hit:\n" + prompt)
|
quick_log(request, body, "Hit:\n" + prompt)
|
||||||
|
33
backend-python/rwkv_pip/webgpu/model.py
vendored
33
backend-python/rwkv_pip/webgpu/model.py
vendored
@ -13,13 +13,6 @@ except ModuleNotFoundError:
|
|||||||
|
|
||||||
class RWKV:
|
class RWKV:
|
||||||
def __init__(self, model_path: str, strategy: str = None):
|
def __init__(self, model_path: str, strategy: str = None):
|
||||||
self.info = wrp.peek_info(model_path)
|
|
||||||
self.w = {} # fake weight
|
|
||||||
self.w["emb.weight"] = [0] * self.info.num_vocab
|
|
||||||
self.version = str(self.info.version).lower()
|
|
||||||
self.wrp = getattr(wrp, self.version)
|
|
||||||
self.version = float(self.version.replace("v", ""))
|
|
||||||
|
|
||||||
layer = (
|
layer = (
|
||||||
int(s.lstrip("layer"))
|
int(s.lstrip("layer"))
|
||||||
for s in strategy.split()
|
for s in strategy.split()
|
||||||
@ -33,21 +26,25 @@ class RWKV:
|
|||||||
for s in s.split(",")
|
for s in s.split(",")
|
||||||
if s.startswith("chunk")
|
if s.startswith("chunk")
|
||||||
)
|
)
|
||||||
|
self.token_chunk_size = next(chunk_size, 32)
|
||||||
|
|
||||||
args = {
|
args = {
|
||||||
"file": model_path,
|
"path": model_path,
|
||||||
"turbo": True,
|
|
||||||
"quant": next(layer, 31) if "i8" in strategy else 0,
|
"quant": next(layer, 31) if "i8" in strategy else 0,
|
||||||
"quant_nf4": next(layer, 26) if "i4" in strategy else 0,
|
"quant_nf4": next(layer, 26) if "i4" in strategy else 0,
|
||||||
"token_chunk_size": next(chunk_size, 32),
|
|
||||||
"lora": None,
|
|
||||||
}
|
}
|
||||||
self.model = self.wrp.Model(**args)
|
self.model = wrp.Model(**args)
|
||||||
|
self.info = self.model.info()
|
||||||
|
self.w = {} # fake weight
|
||||||
|
self.w["emb.weight"] = [0] * self.info.num_vocab
|
||||||
|
self.version = str(self.info.version).lower()
|
||||||
|
self.version = float(self.version.lower().replace("v", ""))
|
||||||
|
|
||||||
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
def forward(self, tokens: List[int], state: Union[Any, None] = None):
|
||||||
if type(state).__name__ == "BackedState": # memory state
|
if state is None:
|
||||||
gpu_state = self.wrp.ModelState(self.model, 1)
|
self.model.clear_state()
|
||||||
gpu_state.load(state)
|
elif type(state).__name__ == "State_Cpu":
|
||||||
else:
|
self.model.load_state(state)
|
||||||
gpu_state = state
|
logits = self.model.run(tokens, self.token_chunk_size)
|
||||||
return self.wrp.run_one(self.model, tokens, gpu_state)
|
ret_state = "State_Gpu"
|
||||||
|
return logits, ret_state
|
||||||
|
Binary file not shown.
Loading…
Reference in New Issue
Block a user