add support for dynamic state-tuned models

This commit is contained in:
josc146
2024-05-12 21:51:24 +08:00
parent b52873cb37
commit a2bbbabee2
12 changed files with 230 additions and 15 deletions

View File

@@ -354,5 +354,10 @@
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "モデル内部には、一般的な問題の処理を改善するためのデフォルトのプロンプトがありますが、役割演技の効果を低下させる可能性があります。このオプションを無効にすることで、より良い役割演技効果を得ることができます。",
"Exit without saving": "保存せずに終了",
"Content has been changed, are you sure you want to exit without saving?": "コンテンツが変更されています、保存せずに終了してもよろしいですか?",
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。"
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。",
"State-tuned Model": "State調整モデル",
"See More": "もっと見る",
"State Model": "Stateモデル",
"State model mismatch": "Stateモデルの不一致",
"File format of the model or state model not supported": "モデルまたはStateモデルのファイル形式がサポートされていません"
}

View File

@@ -354,5 +354,10 @@
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "模型内部有一个默认提示来改善模型处理常规问题的效果, 但它可能会让角色扮演的效果变差, 你可以关闭此选项来获得更好的角色扮演效果",
"Exit without saving": "退出而不保存",
"Content has been changed, are you sure you want to exit without saving?": "内容已经被修改, 你确定要退出而不保存吗?",
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名"
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名",
"State-tuned Model": "State微调模型",
"See More": "查看更多",
"State Model": "State模型",
"State model mismatch": "State模型不匹配",
"File format of the model or state model not supported": "模型或State模型的文件格式不支持"
}

View File

@@ -214,7 +214,18 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
presence_penalty: modelConfig.apiParameters.presencePenalty,
frequency_penalty: modelConfig.apiParameters.frequencyPenalty,
penalty_decay: modelConfig.apiParameters.penaltyDecay,
global_penalty: modelConfig.apiParameters.globalPenalty
global_penalty: modelConfig.apiParameters.globalPenalty,
state: modelConfig.apiParameters.stateModel
}).then(async r => {
if (r.status !== 200) {
const error = await r.text();
if (error.includes('state shape mismatch'))
toast(t('State model mismatch'), { type: 'error' });
else if (error.includes('file format of the model or state model not supported'))
toast(t('File format of the model or state model not supported'), { type: 'error' });
else
toast(error, { type: 'error' });
}
});
}

View File

@@ -7,11 +7,13 @@ import {
Dropdown,
Input,
Label,
Link,
Option,
PresenceBadge,
Select,
Switch,
Text
Text,
Tooltip
} from '@fluentui/react-components';
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
import React, { FC, useCallback, useEffect, useRef } from 'react';
@@ -27,7 +29,7 @@ import { Page } from '../components/Page';
import { useNavigate } from 'react-router';
import { RunButton } from '../components/RunButton';
import { updateConfig } from '../apis';
import { getStrategy } from '../utils';
import { getStrategy, isDynamicStateSupported } from '../utils';
import { useTranslation } from 'react-i18next';
import strategyImg from '../assets/images/strategy.jpg';
import strategyZhImg from '../assets/images/strategy_zh.jpg';
@@ -36,6 +38,7 @@ import { useMediaQuery } from 'usehooks-ts';
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
import { defaultPenaltyDecay } from './defaultConfigs';
import { BrowserOpenURL } from '../../wailsjs/runtime';
const ConfigSelector: FC<{
selectedIndex: number,
@@ -112,6 +115,8 @@ const Configs: FC = observer(() => {
const onClickSave = () => {
commonStore.setModelConfig(selectedIndex, selectedConfig);
// When clicking RunButton in Configs page, updateConfig will be called twice,
// because there are also RunButton in other pages, and the calls to updateConfig in both places are necessary.
updateConfig({
max_tokens: selectedConfig.apiParameters.maxResponseToken,
temperature: selectedConfig.apiParameters.temperature,
@@ -119,7 +124,18 @@ const Configs: FC = observer(() => {
presence_penalty: selectedConfig.apiParameters.presencePenalty,
frequency_penalty: selectedConfig.apiParameters.frequencyPenalty,
penalty_decay: selectedConfig.apiParameters.penaltyDecay,
global_penalty: selectedConfig.apiParameters.globalPenalty
global_penalty: selectedConfig.apiParameters.globalPenalty,
state: selectedConfig.apiParameters.stateModel
}).then(async r => {
if (r.status !== 200) {
const error = await r.text();
if (error.includes('state shape mismatch'))
toast(t('State model mismatch'), { type: 'error' });
else if (error.includes('file format of the model or state model not supported'))
toast(t('File format of the model or state model not supported'), { type: 'error' });
else
toast(error, { type: 'error' });
}
});
toast(t('Config Saved'), { autoClose: 300, type: 'success' });
};
@@ -200,6 +216,34 @@ const Configs: FC = observer(() => {
});
}} />
} />
{isDynamicStateSupported(selectedConfig) &&
<div className="sm:col-span-2 flex gap-2 items-center min-w-0">
<Tooltip content={<div>
{t('State-tuned Model')}, {t('See More')}: <Link
onClick={() => BrowserOpenURL('https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead')}>{'https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead'}
</Link>
</div>} showDelay={0} hideDelay={0}
relationship="description">
<div className="shrink-0">
{t('State Model') + ' *'}
</div>
</Tooltip>
<Select style={{ minWidth: 0 }} className="grow"
value={selectedConfig.apiParameters.stateModel}
onChange={(e, data) => {
setSelectedConfigApiParams({
stateModel: data.value
});
}}>
<option key={-1} value={''}>
{t('None')}
</option>
{commonStore.stateModels.map((modelName, index) =>
<option key={index} value={modelName}>{modelName}</option>
)}
</Select>
</div>
}
<Accordion className="sm:col-span-2" collapsible
openItems={!commonStore.apiParamsCollapsed && 'advanced'}
onToggle={(e, data) => {

View File

@@ -1,6 +1,14 @@
import commonStore, { MonitorData, Platform } from './stores/commonStore';
import { FileExists, GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshLocalModels, refreshModels } from './utils';
import {
bytesToMb,
Cache,
checkUpdate,
downloadProgramFiles,
LocalConfig,
refreshLocalModels,
refreshModels
} from './utils';
import { getStatus } from './apis';
import { EventsOn, WindowSetTitle } from '../wailsjs/runtime';
import manifest from '../../manifest.json';
@@ -29,6 +37,7 @@ export async function startup() {
});
initLocalModelsNotify();
initLoraModels();
initStateModels();
initHardwareMonitor();
initMidi();
}
@@ -124,12 +133,42 @@ async function initLoraModels() {
});
}
/**
 * Builds the list of selectable state models and keeps it up to date.
 *
 * Two directories are scanned, one after the other:
 *  - `state-models`: every non-directory entry whose name ends in `.pth`;
 *  - `models`: non-directory `.pth` entries for which `bytesToMb(size)` is
 *    below 200 (presumably a size cap in MB meant to tell small state files
 *    apart from full model weights — confirm against `bytesToMb`).
 *
 * The combined paths (prefixed with their directory name) are written to
 * `commonStore` via `setStateModels`. The scan runs once immediately and again
 * whenever an `fsnotify` event mentions a matching path.
 */
async function initStateModels() {
  // Throttled with an argument of 2000 (presumably milliseconds — confirm against the throttle helper).
  const rescan = throttle(async () => {
    // Dedicated directory first: any regular *.pth file qualifies.
    const dedicatedEntries = (await ListDirFiles('state-models')) || [];
    const collected = dedicatedEntries
      .filter((entry) => !entry.isDir && entry.name.endsWith('.pth'))
      .map((entry) => 'state-models/' + entry.name);

    // Then the general models directory, restricted to small *.pth files.
    const generalEntries = (await ListDirFiles('models')) || [];
    for (const entry of generalEntries) {
      const isSmallPth =
        !entry.isDir && entry.name.endsWith('.pth') && Number(bytesToMb(entry.size)) < 200;
      if (isSmallPth) {
        collected.push('models/' + entry.name);
      }
    }

    commonStore.setStateModels(collected);
  }, 2000);

  // Populate the list once up front.
  rescan();

  EventsOn('fsnotify', (changedPath: string) => {
    const touchesModels = changedPath.includes('models') && !changedPath.includes('lora-models');
    const touchesStateModels = changedPath.includes('state-models');
    if (touchesModels || touchesStateModels) {
      rescan();
    }
  });
}
/**
 * Subscribes to 'fsnotify' events and triggers a throttled call to
 * refreshLocalModels (using commonStore.modelSourceList) when the event
 * string passes the path checks below.
 */
async function initLocalModelsNotify() {
// Throttled with an argument of 2000 (presumably milliseconds — confirm against the throttle helper).
const throttleRefreshLocalModels = throttle(() => {
refreshLocalModels({ models: commonStore.modelSourceList }, false); //TODO fix bug that only add models
}, 2000);
EventsOn('fsnotify', (data: string) => {
// NOTE(review): the two `if` lines below look like the pre-change and post-change
// versions of the same condition as rendered by the diff view. Read together they
// are equivalent to the second one alone: the event must mention 'models' and
// must not mention 'lora-models' or 'state-models'.
if (data.includes('models') && !data.includes('lora-models'))
if (data.includes('models') && !data.includes('lora-models') && !data.includes('state-models'))
throttleRefreshLocalModels();
});
}

View File

@@ -65,6 +65,7 @@ class CommonStore {
platform: Platform = 'windows';
proxyPort: number = 0;
lastModelName: string = '';
stateModels: string[] = [];
// presets manager
editingPreset: Preset | null = null;
presets: Preset[] = [];
@@ -410,6 +411,10 @@ class CommonStore {
this.loraModels = value;
}
/**
 * Replaces the stored `stateModels` list with `value`.
 *
 * @param value The new list of state model entries.
 */
setStateModels(value: string[]) {
this.stateModels = value;
}
/**
 * Sets the `attachmentUploading` flag to `value`.
 *
 * @param value The new flag value.
 */
setAttachmentUploading(value: boolean) {
this.attachmentUploading = value;
}

View File

@@ -7,6 +7,7 @@ export type ApiParameters = {
frequencyPenalty: number;
penaltyDecay?: number;
globalPenalty?: boolean;
stateModel?: string;
}
/**
 * Identifiers accepted for the `device` setting of a model configuration
 * (presumably the backend the model runs on — confirm against the callers).
 */
export type Device =
  | 'CPU'
  | 'CPU (rwkv.cpp)'
  | 'CUDA'
  | 'CUDA-Beta'
  | 'WebGPU'
  | 'WebGPU (Python)'
  | 'MPS'
  | 'Custom';
/**
 * Identifiers accepted for the precision setting of a model configuration
 * (presumably numeric/quantization formats — confirm against the callers).
 */
export type Precision =
  | 'fp16'
  | 'int8'
  | 'fp32'
  | 'nf4'
  | 'Q5_1';

View File

@@ -676,4 +676,11 @@ export function newChatConversation() {
commonStore.setConversationOrder(conversationOrder);
};
return { pushMessage, saveConversation };
}
/**
 * Tells whether a model configuration's device is one for which a state
 * model may be selected.
 *
 * @param modelConfig The configuration whose `modelParameters.device` is checked.
 * @returns `true` when the device is exactly `'CUDA'`, `'CPU'`, `'Custom'` or
 *   `'MPS'`; `false` for every other value.
 */
export function isDynamicStateSupported(modelConfig: ModelConfig): boolean {
  switch (modelConfig.modelParameters.device) {
    case 'CUDA':
    case 'CPU':
    case 'Custom':
    case 'MPS':
      return true;
    default:
      return false;
  }
}