custom tokenizer (#77)

This commit is contained in:
josc146 2023-09-16 00:34:11 +08:00
parent 971124d0d7
commit a25965530c
6 changed files with 77 additions and 6 deletions

View File

@ -29,6 +29,7 @@ def get_tokens_path(model_path: str):
class SwitchModelBody(BaseModel): class SwitchModelBody(BaseModel):
model: str model: str
strategy: str strategy: str
tokenizer: Union[str, None] = None
customCuda: bool = False customCuda: bool = False
class Config: class Config:
@ -36,6 +37,7 @@ class SwitchModelBody(BaseModel):
"example": { "example": {
"model": "models/RWKV-4-World-3B-v1-20230619-ctx4096.pth", "model": "models/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
"strategy": "cuda fp16", "strategy": "cuda fp16",
"tokenizer": None,
"customCuda": False, "customCuda": False,
} }
} }
@ -65,19 +67,24 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0" os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading) global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
tokenizer = (
get_tokens_path(body.model)
if body.tokenizer is None or body.tokenizer == ""
else body.tokenizer
)
try: try:
global_var.set( global_var.set(
global_var.Model, global_var.Model,
TextRWKV( TextRWKV(
model=body.model, model=body.model,
strategy=body.strategy, strategy=body.strategy,
tokens_path=get_tokens_path(body.model), tokens_path=tokenizer,
) )
if "midi" not in body.model.lower() if "midi" not in body.model.lower()
else MusicRWKV( else MusicRWKV(
model=body.model, model=body.model,
strategy=body.strategy, strategy=body.strategy,
tokens_path=get_tokens_path(body.model), tokens_path=tokenizer,
), ),
) )
except Exception as e: except Exception as e:

View File

@ -248,5 +248,7 @@
"Preview Only": "プレビューのみ", "Preview Only": "プレビューのみ",
"RAM": "RAM", "RAM": "RAM",
"VRAM": "VRAM", "VRAM": "VRAM",
"GPU Usage": "GPU使用率" "GPU Usage": "GPU使用率",
"Use Custom Tokenizer": "カスタムトークナイザーを使用する",
"Tokenizer Path": "トークナイザーパス"
} }

View File

@ -248,5 +248,7 @@
"Preview Only": "仅预览", "Preview Only": "仅预览",
"RAM": "内存", "RAM": "内存",
"VRAM": "显存", "VRAM": "显存",
"GPU Usage": "GPU占用" "GPU Usage": "GPU占用",
"Use Custom Tokenizer": "使用自定义Tokenizer",
"Tokenizer Path": "Tokenizer路径"
} }

View File

@ -186,6 +186,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
switchModel({ switchModel({
model: modelPath, model: modelPath,
strategy: strategy, strategy: strategy,
tokenizer: modelConfig.modelParameters.useCustomTokenizer ? modelConfig.modelParameters.customTokenizer : undefined,
customCuda: customCudaFile !== '' customCuda: customCudaFile !== ''
}).then(async (r) => { }).then(async (r) => {
if (r.ok) { if (r.ok) {

View File

@ -1,6 +1,19 @@
import { Dropdown, Input, Label, Option, Select, Switch, Text } from '@fluentui/react-components'; import {
Accordion,
AccordionHeader,
AccordionItem,
AccordionPanel,
Checkbox,
Dropdown,
Input,
Label,
Option,
Select,
Switch,
Text
} from '@fluentui/react-components';
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons'; import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
import React, { FC } from 'react'; import React, { FC, useEffect, useRef } from 'react';
import { Section } from '../components/Section'; import { Section } from '../components/Section';
import { Labeled } from '../components/Labeled'; import { Labeled } from '../components/Labeled';
import { ToolTipButton } from '../components/ToolTipButton'; import { ToolTipButton } from '../components/ToolTipButton';
@ -43,6 +56,8 @@ export type ModelParameters = {
maxStoredLayers: number; maxStoredLayers: number;
useCustomCuda?: boolean; useCustomCuda?: boolean;
customStrategy?: string; customStrategy?: string;
useCustomTokenizer?: boolean;
customTokenizer?: string;
} }
export type ModelConfig = { export type ModelConfig = {
@ -57,10 +72,16 @@ export const Configs: FC = observer(() => {
const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex); const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex);
const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]); const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]);
const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false); const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false);
const advancedHeaderRef = useRef<HTMLDivElement>(null);
const mq = useMediaQuery('(min-width: 640px)'); const mq = useMediaQuery('(min-width: 640px)');
const navigate = useNavigate(); const navigate = useNavigate();
const port = selectedConfig.apiParameters.apiPort; const port = selectedConfig.apiParameters.apiPort;
useEffect(() => {
if (advancedHeaderRef.current)
(advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0';
}, []);
const updateSelectedIndex = (newIndex: number) => { const updateSelectedIndex = (newIndex: number) => {
setSelectedIndex(newIndex); setSelectedIndex(newIndex);
setSelectedConfig(commonStore.modelConfigs[newIndex]); setSelectedConfig(commonStore.modelConfigs[newIndex]);
@ -412,6 +433,39 @@ export const Configs: FC = observer(() => {
}} /> }} />
} /> } />
} }
{selectedConfig.modelParameters.device !== 'WebGPU' &&
<Accordion className="sm:col-span-2" collapsible
openItems={!commonStore.modelParamsCollapsed && 'advanced'}
onToggle={(e, data) => {
if (data.value === 'advanced')
commonStore.setModelParamsCollapsed(!commonStore.modelParamsCollapsed);
}}>
<AccordionItem value="advanced">
<AccordionHeader ref={advancedHeaderRef} size="small">{t('Advanced')}</AccordionHeader>
<AccordionPanel>
<div className="flex flex-col">
<div className="flex grow">
<Checkbox className="select-none"
size="large" label={t('Use Custom Tokenizer')}
checked={selectedConfig.modelParameters.useCustomTokenizer}
onChange={(_, data) => {
setSelectedConfigModelParams({
useCustomTokenizer: data.checked as boolean
});
}} />
<Input className="grow" placeholder={t('Tokenizer Path')!}
value={selectedConfig.modelParameters.customTokenizer}
onChange={(e, data) => {
setSelectedConfigModelParams({
customTokenizer: data.value
});
}} />
</div>
</div>
</AccordionPanel>
</AccordionItem>
</Accordion>
}
</div> </div>
} }
/> />

View File

@ -74,6 +74,7 @@ class CommonStore {
// configs // configs
currentModelConfigIndex: number = 0; currentModelConfigIndex: number = 0;
modelConfigs: ModelConfig[] = []; modelConfigs: ModelConfig[] = [];
modelParamsCollapsed: boolean = true;
// models // models
modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;'; modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;';
modelSourceList: ModelSourceItem[] = []; modelSourceList: ModelSourceItem[] = [];
@ -259,6 +260,10 @@ class CommonStore {
this.advancedCollapsed = value; this.advancedCollapsed = value;
} }
setModelParamsCollapsed(value: boolean) {
this.modelParamsCollapsed = value;
}
setLastUnfinishedModelDownloads(value: DownloadStatus[]) { setLastUnfinishedModelDownloads(value: DownloadStatus[]) {
this.lastUnfinishedModelDownloads = value; this.lastUnfinishedModelDownloads = value;
} }