This commit is contained in:
josc146
2024-05-28 21:27:10 +08:00
parent 6b4381ee77
commit 3488d22d22
3 changed files with 28 additions and 24 deletions

View File

@@ -96,7 +96,9 @@ def copy_tensor_to_cpu(tensors):
elif tensors_type == np.ndarray: # rwkv.cpp
copied = tensors
else: # WebGPU state
copied = tensors.back()
model = global_var.get(global_var.Model)
if model:
copied = model.model.model.back_state()
return copied, devices
@@ -238,11 +240,14 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
state: Union[Any, None] = v["state"]
logits: Union[Any, None] = v["logits"]
if type(state) == list and hasattr(state[0], "device"): # torch
state_type = type(state)
if state_type == list and hasattr(state[0], "device"): # torch
state = [
tensor.to(devices[i])
if devices[i] != torch.device("cpu")
else tensor.clone()
(
tensor.to(devices[i])
if devices[i] != torch.device("cpu")
else tensor.clone()
)
for i, tensor in enumerate(state)
]
logits = (
@@ -250,7 +255,9 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
if logits_device != torch.device("cpu")
else logits.clone()
)
else: # rwkv.cpp, WebGPU
elif state_type == np.ndarray: # rwkv.cpp
logits = np.copy(logits)
else: # WebGPU
logits = np.copy(logits)
quick_log(request, body, "Hit:\n" + prompt)