import os
import sys

sys.path.append(os.path.dirname(os.path.realpath(__file__)))

import psutil
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from utils.rwkv import *
from utils.torch import *
from utils.ngrok import *
from utils.log import log_middleware
from routes import completion, config, state_cache, midi
import global_var

app = FastAPI(dependencies=[Depends(log_middleware)])

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(completion.router)
app.include_router(config.router)
app.include_router(midi.router)
app.include_router(state_cache.router)


@app.on_event("startup")
def init():
    global_var.init()
    state_cache.init()

    set_torch()

    if os.environ.get("ngrok_token") is not None:
        ngrok_connect()


@app.get("/", tags=["Root"])
def read_root():
    return {"Hello": "World!"}


@app.post("/exit", tags=["Root"])
def exit():
    parent_pid = os.getpid()
    parent = psutil.Process(parent_pid)
    for child in parent.children(recursive=True):
        child.kill()
    parent.kill()


if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        port=8000 if len(sys.argv) < 2 else int(sys.argv[1]),
        host="127.0.0.1" if len(sys.argv) < 3 else sys.argv[2],
    )