fix body.state
This commit is contained in:
parent
7f3cfd54b0
commit
7e2380e4ed
@ -94,7 +94,7 @@ def add_state(body: AddStateBody):
|
||||
state: Union[Any, None] = None
|
||||
|
||||
if body.state is not None:
|
||||
if type(state) == list and hasattr(state[0], "device"): # torch
|
||||
if type(body.state) == list and hasattr(body.state[0], "device"): # torch
|
||||
devices = [tensor.device for tensor in body.state]
|
||||
state = [tensor.cpu() for tensor in body.state]
|
||||
elif type(body.state) == np.ndarray: # rwkv.cpp
|
||||
|
Loading…
Reference in New Issue
Block a user