diff --git a/backend-python/routes/config.py b/backend-python/routes/config.py index e21fbcf..9be72b7 100644 --- a/backend-python/routes/config.py +++ b/backend-python/routes/config.py @@ -29,6 +29,7 @@ def get_tokens_path(model_path: str): class SwitchModelBody(BaseModel): model: str strategy: str + tokenizer: Union[str, None] = None customCuda: bool = False class Config: @@ -36,6 +37,7 @@ class SwitchModelBody(BaseModel): "example": { "model": "models/RWKV-4-World-3B-v1-20230619-ctx4096.pth", "strategy": "cuda fp16", + "tokenizer": None, "customCuda": False, } } @@ -65,19 +67,24 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request): os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0" global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading) + tokenizer = ( + get_tokens_path(body.model) + if body.tokenizer is None or body.tokenizer == "" + else body.tokenizer + ) try: global_var.set( global_var.Model, TextRWKV( model=body.model, strategy=body.strategy, - tokens_path=get_tokens_path(body.model), + tokens_path=tokenizer, ) if "midi" not in body.model.lower() else MusicRWKV( model=body.model, strategy=body.strategy, - tokens_path=get_tokens_path(body.model), + tokens_path=tokenizer, ), ) except Exception as e: diff --git a/frontend/src/_locales/ja/main.json b/frontend/src/_locales/ja/main.json index 23b667b..6e08a60 100644 --- a/frontend/src/_locales/ja/main.json +++ b/frontend/src/_locales/ja/main.json @@ -248,5 +248,7 @@ "Preview Only": "プレビューのみ", "RAM": "RAM", "VRAM": "VRAM", - "GPU Usage": "GPU使用率" + "GPU Usage": "GPU使用率", + "Use Custom Tokenizer": "カスタムトークナイザーを使用する", + "Tokenizer Path": "トークナイザーパス" } \ No newline at end of file diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json index 2e5536c..16bed8e 100644 --- a/frontend/src/_locales/zh-hans/main.json +++ b/frontend/src/_locales/zh-hans/main.json @@ -248,5 +248,7 @@ "Preview Only": "仅预览", "RAM": "内存", "VRAM": "显存", - "GPU Usage": "GPU占用" + "GPU Usage": "GPU占用", + "Use Custom Tokenizer": "使用自定义Tokenizer", + "Tokenizer Path": "Tokenizer路径" } \ No newline at end of file diff --git a/frontend/src/components/RunButton.tsx b/frontend/src/components/RunButton.tsx index 58fcf57..2f43faa 100644 --- a/frontend/src/components/RunButton.tsx +++ b/frontend/src/components/RunButton.tsx @@ -186,6 +186,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean switchModel({ model: modelPath, strategy: strategy, + tokenizer: modelConfig.modelParameters.useCustomTokenizer ? modelConfig.modelParameters.customTokenizer : undefined, customCuda: customCudaFile !== '' }).then(async (r) => { if (r.ok) { diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index a6d8296..d26c5ce 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -1,6 +1,19 @@ -import { Dropdown, Input, Label, Option, Select, Switch, Text } from '@fluentui/react-components'; +import { + Accordion, + AccordionHeader, + AccordionItem, + AccordionPanel, + Checkbox, + Dropdown, + Input, + Label, + Option, + Select, + Switch, + Text +} from '@fluentui/react-components'; import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons'; -import React, { FC } from 'react'; +import React, { FC, useEffect, useRef } from 'react'; import { Section } from '../components/Section'; import { Labeled } from '../components/Labeled'; import { ToolTipButton } from '../components/ToolTipButton'; @@ -43,6 +56,8 @@ export type ModelParameters = { maxStoredLayers: number; useCustomCuda?: boolean; customStrategy?: string; + useCustomTokenizer?: boolean; + customTokenizer?: string; } export type ModelConfig = { @@ -57,10 +72,16 @@ export const Configs: FC = observer(() => { const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex); const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]); const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false); + const advancedHeaderRef = useRef(null); const mq = useMediaQuery('(min-width: 640px)'); const navigate = useNavigate(); const port = selectedConfig.apiParameters.apiPort; + useEffect(() => { + if (advancedHeaderRef.current) + (advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0'; + }, []); + const updateSelectedIndex = (newIndex: number) => { setSelectedIndex(newIndex); setSelectedConfig(commonStore.modelConfigs[newIndex]); @@ -412,6 +433,39 @@ export const Configs: FC = observer(() => { }} /> } /> } + {selectedConfig.modelParameters.device !== 'WebGPU' && + { + if (data.value === 'advanced') + commonStore.setModelParamsCollapsed(!commonStore.modelParamsCollapsed); + }}> + + {t('Advanced')} + +
+
+ { + setSelectedConfigModelParams({ + useCustomTokenizer: data.checked as boolean + }); + }} /> + { + setSelectedConfigModelParams({ + customTokenizer: data.value + }); + }} /> +
+
+
+
+
+ } } /> diff --git a/frontend/src/stores/commonStore.ts b/frontend/src/stores/commonStore.ts index 2a6c12a..5ce3867 100644 --- a/frontend/src/stores/commonStore.ts +++ b/frontend/src/stores/commonStore.ts @@ -74,6 +74,7 @@ class CommonStore { // configs currentModelConfigIndex: number = 0; modelConfigs: ModelConfig[] = []; + modelParamsCollapsed: boolean = true; // models modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;'; modelSourceList: ModelSourceItem[] = []; @@ -259,6 +260,10 @@ class CommonStore { this.advancedCollapsed = value; } + setModelParamsCollapsed(value: boolean) { + this.modelParamsCollapsed = value; + } + setLastUnfinishedModelDownloads(value: DownloadStatus[]) { this.lastUnfinishedModelDownloads = value; }