diff --git a/backend-python/rwkv_pip/cpp/model.py b/backend-python/rwkv_pip/cpp/model.py index 1a5a074..5d34948 100644 --- a/backend-python/rwkv_pip/cpp/model.py +++ b/backend-python/rwkv_pip/cpp/model.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Union from . import rwkv_cpp_model from . import rwkv_cpp_shared_library @@ -10,5 +10,5 @@ class RWKV: self.w = {} # fake weight self.w["emb.weight"] = [0] * self.model.n_vocab - def forward(self, tokens: List[int], state: Any | None): + def forward(self, tokens: List[int], state: Union[Any, None] = None): return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)