fixed torch version; CUDA acceleration utils

This commit is contained in:
josc146
2023-05-23 11:19:39 +08:00
parent ecb5d6c6e4
commit 7989e93afe
6 changed files with 745 additions and 2 deletions

View File

@@ -1,3 +1,5 @@
import os
import pathlib
from typing import Dict
from langchain.llms import RWKV
from pydantic import BaseModel
@@ -34,6 +36,10 @@ def get_rwkv_config(model: RWKV) -> ModelConfigBody:
)
# os.environ["RWKV_CUDA_ON"] = '1'
# os.environ["TORCH_EXTENSIONS_DIR"] = f"{pathlib.Path(__file__).parent.parent.resolve()}"
def rwkv_generate(model: RWKV, prompt: str, stop: str = None):
model.model_state = None
model.model_tokens = []