custom tokenizer (#77)
This commit is contained in:
parent
971124d0d7
commit
a25965530c
@ -29,6 +29,7 @@ def get_tokens_path(model_path: str):
|
|||||||
class SwitchModelBody(BaseModel):
|
class SwitchModelBody(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
strategy: str
|
strategy: str
|
||||||
|
tokenizer: Union[str, None] = None
|
||||||
customCuda: bool = False
|
customCuda: bool = False
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -36,6 +37,7 @@ class SwitchModelBody(BaseModel):
|
|||||||
"example": {
|
"example": {
|
||||||
"model": "models/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
|
"model": "models/RWKV-4-World-3B-v1-20230619-ctx4096.pth",
|
||||||
"strategy": "cuda fp16",
|
"strategy": "cuda fp16",
|
||||||
|
"tokenizer": None,
|
||||||
"customCuda": False,
|
"customCuda": False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -65,19 +67,24 @@ def switch_model(body: SwitchModelBody, response: Response, request: Request):
|
|||||||
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
os.environ["RWKV_CUDA_ON"] = "1" if body.customCuda else "0"
|
||||||
|
|
||||||
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
global_var.set(global_var.Model_Status, global_var.ModelStatus.Loading)
|
||||||
|
tokenizer = (
|
||||||
|
get_tokens_path(body.model)
|
||||||
|
if body.tokenizer is None or body.tokenizer == ""
|
||||||
|
else body.tokenizer
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
global_var.set(
|
global_var.set(
|
||||||
global_var.Model,
|
global_var.Model,
|
||||||
TextRWKV(
|
TextRWKV(
|
||||||
model=body.model,
|
model=body.model,
|
||||||
strategy=body.strategy,
|
strategy=body.strategy,
|
||||||
tokens_path=get_tokens_path(body.model),
|
tokens_path=tokenizer,
|
||||||
)
|
)
|
||||||
if "midi" not in body.model.lower()
|
if "midi" not in body.model.lower()
|
||||||
else MusicRWKV(
|
else MusicRWKV(
|
||||||
model=body.model,
|
model=body.model,
|
||||||
strategy=body.strategy,
|
strategy=body.strategy,
|
||||||
tokens_path=get_tokens_path(body.model),
|
tokens_path=tokenizer,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -248,5 +248,7 @@
|
|||||||
"Preview Only": "プレビューのみ",
|
"Preview Only": "プレビューのみ",
|
||||||
"RAM": "RAM",
|
"RAM": "RAM",
|
||||||
"VRAM": "VRAM",
|
"VRAM": "VRAM",
|
||||||
"GPU Usage": "GPU使用率"
|
"GPU Usage": "GPU使用率",
|
||||||
|
"Use Custom Tokenizer": "カスタムトークナイザーを使用する",
|
||||||
|
"Tokenizer Path": "トークナイザーパス"
|
||||||
}
|
}
|
@ -248,5 +248,7 @@
|
|||||||
"Preview Only": "仅预览",
|
"Preview Only": "仅预览",
|
||||||
"RAM": "内存",
|
"RAM": "内存",
|
||||||
"VRAM": "显存",
|
"VRAM": "显存",
|
||||||
"GPU Usage": "GPU占用"
|
"GPU Usage": "GPU占用",
|
||||||
|
"Use Custom Tokenizer": "使用自定义Tokenizer",
|
||||||
|
"Tokenizer Path": "Tokenizer路径"
|
||||||
}
|
}
|
@ -186,6 +186,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
|||||||
switchModel({
|
switchModel({
|
||||||
model: modelPath,
|
model: modelPath,
|
||||||
strategy: strategy,
|
strategy: strategy,
|
||||||
|
tokenizer: modelConfig.modelParameters.useCustomTokenizer ? modelConfig.modelParameters.customTokenizer : undefined,
|
||||||
customCuda: customCudaFile !== ''
|
customCuda: customCudaFile !== ''
|
||||||
}).then(async (r) => {
|
}).then(async (r) => {
|
||||||
if (r.ok) {
|
if (r.ok) {
|
||||||
|
@ -1,6 +1,19 @@
|
|||||||
import { Dropdown, Input, Label, Option, Select, Switch, Text } from '@fluentui/react-components';
|
import {
|
||||||
|
Accordion,
|
||||||
|
AccordionHeader,
|
||||||
|
AccordionItem,
|
||||||
|
AccordionPanel,
|
||||||
|
Checkbox,
|
||||||
|
Dropdown,
|
||||||
|
Input,
|
||||||
|
Label,
|
||||||
|
Option,
|
||||||
|
Select,
|
||||||
|
Switch,
|
||||||
|
Text
|
||||||
|
} from '@fluentui/react-components';
|
||||||
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
|
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
|
||||||
import React, { FC } from 'react';
|
import React, { FC, useEffect, useRef } from 'react';
|
||||||
import { Section } from '../components/Section';
|
import { Section } from '../components/Section';
|
||||||
import { Labeled } from '../components/Labeled';
|
import { Labeled } from '../components/Labeled';
|
||||||
import { ToolTipButton } from '../components/ToolTipButton';
|
import { ToolTipButton } from '../components/ToolTipButton';
|
||||||
@ -43,6 +56,8 @@ export type ModelParameters = {
|
|||||||
maxStoredLayers: number;
|
maxStoredLayers: number;
|
||||||
useCustomCuda?: boolean;
|
useCustomCuda?: boolean;
|
||||||
customStrategy?: string;
|
customStrategy?: string;
|
||||||
|
useCustomTokenizer?: boolean;
|
||||||
|
customTokenizer?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ModelConfig = {
|
export type ModelConfig = {
|
||||||
@ -57,10 +72,16 @@ export const Configs: FC = observer(() => {
|
|||||||
const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex);
|
const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex);
|
||||||
const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]);
|
const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]);
|
||||||
const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false);
|
const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false);
|
||||||
|
const advancedHeaderRef = useRef<HTMLDivElement>(null);
|
||||||
const mq = useMediaQuery('(min-width: 640px)');
|
const mq = useMediaQuery('(min-width: 640px)');
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const port = selectedConfig.apiParameters.apiPort;
|
const port = selectedConfig.apiParameters.apiPort;
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (advancedHeaderRef.current)
|
||||||
|
(advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0';
|
||||||
|
}, []);
|
||||||
|
|
||||||
const updateSelectedIndex = (newIndex: number) => {
|
const updateSelectedIndex = (newIndex: number) => {
|
||||||
setSelectedIndex(newIndex);
|
setSelectedIndex(newIndex);
|
||||||
setSelectedConfig(commonStore.modelConfigs[newIndex]);
|
setSelectedConfig(commonStore.modelConfigs[newIndex]);
|
||||||
@ -412,6 +433,39 @@ export const Configs: FC = observer(() => {
|
|||||||
}} />
|
}} />
|
||||||
} />
|
} />
|
||||||
}
|
}
|
||||||
|
{selectedConfig.modelParameters.device !== 'WebGPU' &&
|
||||||
|
<Accordion className="sm:col-span-2" collapsible
|
||||||
|
openItems={!commonStore.modelParamsCollapsed && 'advanced'}
|
||||||
|
onToggle={(e, data) => {
|
||||||
|
if (data.value === 'advanced')
|
||||||
|
commonStore.setModelParamsCollapsed(!commonStore.modelParamsCollapsed);
|
||||||
|
}}>
|
||||||
|
<AccordionItem value="advanced">
|
||||||
|
<AccordionHeader ref={advancedHeaderRef} size="small">{t('Advanced')}</AccordionHeader>
|
||||||
|
<AccordionPanel>
|
||||||
|
<div className="flex flex-col">
|
||||||
|
<div className="flex grow">
|
||||||
|
<Checkbox className="select-none"
|
||||||
|
size="large" label={t('Use Custom Tokenizer')}
|
||||||
|
checked={selectedConfig.modelParameters.useCustomTokenizer}
|
||||||
|
onChange={(_, data) => {
|
||||||
|
setSelectedConfigModelParams({
|
||||||
|
useCustomTokenizer: data.checked as boolean
|
||||||
|
});
|
||||||
|
}} />
|
||||||
|
<Input className="grow" placeholder={t('Tokenizer Path')!}
|
||||||
|
value={selectedConfig.modelParameters.customTokenizer}
|
||||||
|
onChange={(e, data) => {
|
||||||
|
setSelectedConfigModelParams({
|
||||||
|
customTokenizer: data.value
|
||||||
|
});
|
||||||
|
}} />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</AccordionPanel>
|
||||||
|
</AccordionItem>
|
||||||
|
</Accordion>
|
||||||
|
}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
|
@ -74,6 +74,7 @@ class CommonStore {
|
|||||||
// configs
|
// configs
|
||||||
currentModelConfigIndex: number = 0;
|
currentModelConfigIndex: number = 0;
|
||||||
modelConfigs: ModelConfig[] = [];
|
modelConfigs: ModelConfig[] = [];
|
||||||
|
modelParamsCollapsed: boolean = true;
|
||||||
// models
|
// models
|
||||||
modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;';
|
modelSourceManifestList: string = 'https://cdn.jsdelivr.net/gh/josstorer/RWKV-Runner@master/manifest.json;';
|
||||||
modelSourceList: ModelSourceItem[] = [];
|
modelSourceList: ModelSourceItem[] = [];
|
||||||
@ -259,6 +260,10 @@ class CommonStore {
|
|||||||
this.advancedCollapsed = value;
|
this.advancedCollapsed = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setModelParamsCollapsed(value: boolean) {
|
||||||
|
this.modelParamsCollapsed = value;
|
||||||
|
}
|
||||||
|
|
||||||
setLastUnfinishedModelDownloads(value: DownloadStatus[]) {
|
setLastUnfinishedModelDownloads(value: DownloadStatus[]) {
|
||||||
this.lastUnfinishedModelDownloads = value;
|
this.lastUnfinishedModelDownloads = value;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user