fix body.state

This commit is contained in:
josc146 2023-12-28 23:53:58 +08:00
parent 7f3cfd54b0
commit 7e2380e4ed

View File

@ -94,7 +94,7 @@ def add_state(body: AddStateBody):
state: Union[Any, None] = None
if body.state is not None:
if type(state) == list and hasattr(state[0], "device"): # torch
if type(body.state) == list and hasattr(body.state[0], "device"): # torch
devices = [tensor.device for tensor in body.state]
state = [tensor.cpu() for tensor in body.state]
elif type(body.state) == np.ndarray: # rwkv.cpp