Bump webgpu (Python) (https://github.com/cryscan/web-rwkv-py)
This commit is contained in:
@@ -96,7 +96,9 @@ def copy_tensor_to_cpu(tensors):
|
||||
elif tensors_type == np.ndarray: # rwkv.cpp
|
||||
copied = tensors
|
||||
else: # WebGPU state
|
||||
copied = tensors.back()
|
||||
model = global_var.get(global_var.Model)
|
||||
if model:
|
||||
copied = model.model.model.back_state()
|
||||
|
||||
return copied, devices
|
||||
|
||||
@@ -238,11 +240,14 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
state: Union[Any, None] = v["state"]
|
||||
logits: Union[Any, None] = v["logits"]
|
||||
|
||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||
state_type = type(state)
|
||||
if state_type == list and hasattr(state[0], "device"): # torch
|
||||
state = [
|
||||
tensor.to(devices[i])
|
||||
if devices[i] != torch.device("cpu")
|
||||
else tensor.clone()
|
||||
(
|
||||
tensor.to(devices[i])
|
||||
if devices[i] != torch.device("cpu")
|
||||
else tensor.clone()
|
||||
)
|
||||
for i, tensor in enumerate(state)
|
||||
]
|
||||
logits = (
|
||||
@@ -250,7 +255,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
if logits_device != torch.device("cpu")
|
||||
else logits.clone()
|
||||
)
|
||||
else: # rwkv.cpp, WebGPU
|
||||
elif state_type == np.ndarray: # rwkv.cpp
|
||||
logits = np.copy(logits)
|
||||
else: # WebGPU
|
||||
logits = np.copy(logits)
|
||||
|
||||
quick_log(request, body, "Hit:\n" + prompt)
|
||||
|
||||
Reference in New Issue
Block a user