fixed torch version; CUDA acceleration utils
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Dict
|
||||
from langchain.llms import RWKV
|
||||
from pydantic import BaseModel
|
||||
@@ -34,6 +36,10 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody:
|
||||
)
|
||||
|
||||
|
||||
# os.environ["RWKV_CUDA_ON"] = '1'
|
||||
# os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
|
||||
|
||||
|
||||
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
|
||||
model.model_state = None
|
||||
model.model_tokens = []
|
||||
|
||||
Reference in New Issue
Block a user