improve startup process

This commit is contained in:
josc146 2023-11-04 20:21:55 +08:00
parent 1f81a1e5a8
commit 1dcda47013

View File

@ -2,16 +2,47 @@ import time
start_time = time.time() start_time = time.time()
import setuptools # avoid warnings import argparse
from typing import Union, Sequence
def get_args(args: Union[Sequence[str], None] = None):
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="server arguments")
group.add_argument(
"--port",
type=int,
default=8000,
help="port to run the server on (default: 8000)",
)
group.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="host to run the server on (default: 127.0.0.1)",
)
group = parser.add_argument_group(title="mode arguments")
group.add_argument(
"--rwkv-beta",
action="store_true",
help="whether to use rwkv-beta (default: False)",
)
args = parser.parse_args(args)
return args
if __name__ == "__main__":
args = get_args()
import os import os
import sys import sys
import argparse
from typing import Sequence
from contextlib import asynccontextmanager
sys.path.append(os.path.dirname(os.path.realpath(__file__))) sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import psutil import psutil
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import uvicorn import uvicorn
@ -77,34 +108,7 @@ def exit():
parent.kill() parent.kill()
def get_args(args: Union[Sequence[str], None] = None):
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="server arguments")
group.add_argument(
"--port",
type=int,
default=8000,
help="port to run the server on (default: 8000)",
)
group.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="host to run the server on (default: 127.0.0.1)",
)
group = parser.add_argument_group(title="mode arguments")
group.add_argument(
"--rwkv-beta",
action="store_true",
help="whether to use rwkv-beta (default: False)",
)
args = parser.parse_args(args)
return args
if __name__ == "__main__": if __name__ == "__main__":
args = get_args()
os.environ["RWKV_RUNNER_PARAMS"] = " ".join(sys.argv[1:]) os.environ["RWKV_RUNNER_PARAMS"] = " ".join(sys.argv[1:])
print("--- %s seconds ---" % (time.time() - start_time)) print("--- %s seconds ---" % (time.time() - start_time))
uvicorn.run("main:app", port=args.port, host=args.host, workers=1) uvicorn.run("main:app", port=args.port, host=args.host, workers=1)