add WebGPU Python Mode (https://github.com/cryscan/web-rwkv-py)
This commit is contained in:
@@ -8,7 +8,6 @@ import base64
|
||||
from fastapi import APIRouter, Request, status, HTTPException
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import numpy as np
|
||||
import tiktoken
|
||||
from utils.rwkv import *
|
||||
from utils.log import quick_log
|
||||
@@ -396,6 +395,8 @@ class EmbeddingsBody(BaseModel):
|
||||
|
||||
|
||||
def embedding_base64(embedding: List[float]) -> str:
|
||||
import numpy as np
|
||||
|
||||
return base64.b64encode(np.array(embedding).astype(np.float32)).decode("utf-8")
|
||||
|
||||
|
||||
|
||||
@@ -87,18 +87,34 @@ def add_state(body: AddStateBody):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
devices: List[torch.device] = []
|
||||
state: Union[Any, None] = None
|
||||
|
||||
if body.state is not None:
|
||||
if type(body.state) == list or type(body.state) == np.ndarray:
|
||||
devices = [
|
||||
(
|
||||
tensor.device
|
||||
if hasattr(tensor, "device")
|
||||
else torch.device("cpu")
|
||||
)
|
||||
for tensor in body.state
|
||||
]
|
||||
state = (
|
||||
[tensor.cpu() for tensor in body.state]
|
||||
if hasattr(body.state[0], "device")
|
||||
else copy.deepcopy(body.state)
|
||||
)
|
||||
else:
|
||||
pass # WebGPU
|
||||
|
||||
id: int = trie.insert(body.prompt)
|
||||
devices: List[torch.device] = [
|
||||
(tensor.device if hasattr(tensor, "device") else torch.device("cpu"))
|
||||
for tensor in body.state
|
||||
]
|
||||
dtrie[id] = {
|
||||
"tokens": copy.deepcopy(body.tokens),
|
||||
"state": [tensor.cpu() for tensor in body.state]
|
||||
if hasattr(body.state[0], "device")
|
||||
else copy.deepcopy(body.state),
|
||||
"state": state,
|
||||
"logits": copy.deepcopy(body.logits),
|
||||
"devices": devices,
|
||||
}
|
||||
@@ -174,6 +190,7 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "trie not loaded")
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
id = -1
|
||||
try:
|
||||
@@ -185,14 +202,16 @@ def longest_prefix_state(body: LongestPrefixStateBody, request: Request):
|
||||
v = dtrie[id]
|
||||
devices: List[torch.device] = v["devices"]
|
||||
prompt: str = trie[id]
|
||||
state: Union[Any, None] = v["state"]
|
||||
|
||||
if state is not None and type(state) == list and hasattr(state[0], "device"):
|
||||
state = [tensor.to(devices[i]) for i, tensor in enumerate(state)]
|
||||
|
||||
quick_log(request, body, "Hit:\n" + prompt)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"tokens": v["tokens"],
|
||||
"state": [tensor.to(devices[i]) for i, tensor in enumerate(v["state"])]
|
||||
if hasattr(v["state"][0], "device")
|
||||
else v["state"],
|
||||
"state": state,
|
||||
"logits": v["logits"],
|
||||
}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user