backend api
This commit is contained in:
9
backend-python/utils/ngrok.py
Normal file
9
backend-python/utils/ngrok.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import os
|
||||
|
||||
|
||||
def ngrok_connect():
|
||||
from pyngrok import ngrok, conf
|
||||
conf.set_default(conf.PyngrokConfig(ngrok_path="./ngrok"))
|
||||
ngrok.set_auth_token(os.environ["ngrok_token"])
|
||||
http_tunnel = ngrok.connect(8000)
|
||||
print(http_tunnel.public_url)
|
||||
40
backend-python/utils/rwkv.py
Normal file
40
backend-python/utils/rwkv.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Dict
|
||||
from langchain.llms import RWKV
|
||||
|
||||
|
||||
def rwkv_generate(model: RWKV, prompt: str):
|
||||
model.model_state = None
|
||||
model.model_tokens = []
|
||||
logits = model.run_rnn(model.tokenizer.encode(prompt).ids)
|
||||
begin = len(model.model_tokens)
|
||||
out_last = begin
|
||||
|
||||
occurrence: Dict = {}
|
||||
|
||||
response = ""
|
||||
for i in range(model.max_tokens_per_generation):
|
||||
for n in occurrence:
|
||||
logits[n] -= (
|
||||
model.penalty_alpha_presence
|
||||
+ occurrence[n] * model.penalty_alpha_frequency
|
||||
)
|
||||
token = model.pipeline.sample_logits(
|
||||
logits, temperature=model.temperature, top_p=model.top_p
|
||||
)
|
||||
|
||||
END_OF_TEXT = 0
|
||||
if token == END_OF_TEXT:
|
||||
break
|
||||
if token not in occurrence:
|
||||
occurrence[token] = 1
|
||||
else:
|
||||
occurrence[token] += 1
|
||||
|
||||
logits = model.run_rnn([token])
|
||||
delta: str = model.tokenizer.decode(model.model_tokens[out_last:])
|
||||
if "\ufffd" not in delta: # avoid utf-8 display issues
|
||||
response += delta
|
||||
yield response, delta
|
||||
out_last = begin + i + 1
|
||||
if i >= model.max_tokens_per_generation - 100:
|
||||
break
|
||||
26
backend-python/utils/torch.py
Normal file
26
backend-python/utils/torch.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import os
|
||||
import sysconfig
|
||||
|
||||
|
||||
def set_torch():
|
||||
torch_path = os.path.join(sysconfig.get_paths()["purelib"], "torch\\lib")
|
||||
paths = os.environ.get("PATH", "")
|
||||
if os.path.exists(torch_path):
|
||||
print(f"torch found: {torch_path}")
|
||||
if torch_path in paths:
|
||||
print("torch already set")
|
||||
else:
|
||||
print("run:")
|
||||
os.environ['PATH'] = paths + os.pathsep + torch_path + os.pathsep
|
||||
print(f'set Path={paths + os.pathsep + torch_path + os.pathsep}')
|
||||
else:
|
||||
print("torch not found")
|
||||
|
||||
|
||||
def torch_gc():
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(0):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
Reference in New Issue
Block a user