add support for dynamic state-tuned models
This commit is contained in:
@@ -354,5 +354,10 @@
|
||||
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "モデル内部には、一般的な問題の処理を改善するためのデフォルトのプロンプトがありますが、役割演技の効果を低下させる可能性があります。このオプションを無効にすることで、より良い役割演技効果を得ることができます。",
|
||||
"Exit without saving": "保存せずに終了",
|
||||
"Content has been changed, are you sure you want to exit without saving?": "コンテンツが変更されています、保存せずに終了してもよろしいですか?",
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。"
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。",
|
||||
"State-tuned Model": "State調整モデル",
|
||||
"See More": "もっと見る",
|
||||
"State Model": "Stateモデル",
|
||||
"State model mismatch": "Stateモデルの不一致",
|
||||
"File format of the model or state model not supported": "モデルまたはStateモデルのファイル形式がサポートされていません"
|
||||
}
|
||||
@@ -354,5 +354,10 @@
|
||||
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "模型内部有一个默认提示来改善模型处理常规问题的效果, 但它可能会让角色扮演的效果变差, 你可以关闭此选项来获得更好的角色扮演效果",
|
||||
"Exit without saving": "退出而不保存",
|
||||
"Content has been changed, are you sure you want to exit without saving?": "内容已经被修改, 你确定要退出而不保存吗?",
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名"
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名",
|
||||
"State-tuned Model": "State微调模型",
|
||||
"See More": "查看更多",
|
||||
"State Model": "State模型",
|
||||
"State model mismatch": "State模型不匹配",
|
||||
"File format of the model or state model not supported": "模型或state模型的文件格式不支持"
|
||||
}
|
||||
@@ -214,7 +214,18 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
presence_penalty: modelConfig.apiParameters.presencePenalty,
|
||||
frequency_penalty: modelConfig.apiParameters.frequencyPenalty,
|
||||
penalty_decay: modelConfig.apiParameters.penaltyDecay,
|
||||
global_penalty: modelConfig.apiParameters.globalPenalty
|
||||
global_penalty: modelConfig.apiParameters.globalPenalty,
|
||||
state: modelConfig.apiParameters.stateModel
|
||||
}).then(async r => {
|
||||
if (r.status !== 200) {
|
||||
const error = await r.text();
|
||||
if (error.includes('state shape mismatch'))
|
||||
toast(t('State model mismatch'), { type: 'error' });
|
||||
else if (error.includes('file format of the model or state model not supported'))
|
||||
toast(t('File format of the model or state model not supported'), { type: 'error' });
|
||||
else
|
||||
toast(error, { type: 'error' });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -7,11 +7,13 @@ import {
|
||||
Dropdown,
|
||||
Input,
|
||||
Label,
|
||||
Link,
|
||||
Option,
|
||||
PresenceBadge,
|
||||
Select,
|
||||
Switch,
|
||||
Text
|
||||
Text,
|
||||
Tooltip
|
||||
} from '@fluentui/react-components';
|
||||
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
|
||||
import React, { FC, useCallback, useEffect, useRef } from 'react';
|
||||
@@ -27,7 +29,7 @@ import { Page } from '../components/Page';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { RunButton } from '../components/RunButton';
|
||||
import { updateConfig } from '../apis';
|
||||
import { getStrategy } from '../utils';
|
||||
import { getStrategy, isDynamicStateSupported } from '../utils';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import strategyImg from '../assets/images/strategy.jpg';
|
||||
import strategyZhImg from '../assets/images/strategy_zh.jpg';
|
||||
@@ -36,6 +38,7 @@ import { useMediaQuery } from 'usehooks-ts';
|
||||
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
|
||||
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
|
||||
import { defaultPenaltyDecay } from './defaultConfigs';
|
||||
import { BrowserOpenURL } from '../../wailsjs/runtime';
|
||||
|
||||
const ConfigSelector: FC<{
|
||||
selectedIndex: number,
|
||||
@@ -112,6 +115,8 @@ const Configs: FC = observer(() => {
|
||||
|
||||
const onClickSave = () => {
|
||||
commonStore.setModelConfig(selectedIndex, selectedConfig);
|
||||
// When clicking RunButton in Configs page, updateConfig will be called twice,
|
||||
// because there are also RunButton in other pages, and the calls to updateConfig in both places are necessary.
|
||||
updateConfig({
|
||||
max_tokens: selectedConfig.apiParameters.maxResponseToken,
|
||||
temperature: selectedConfig.apiParameters.temperature,
|
||||
@@ -119,7 +124,18 @@ const Configs: FC = observer(() => {
|
||||
presence_penalty: selectedConfig.apiParameters.presencePenalty,
|
||||
frequency_penalty: selectedConfig.apiParameters.frequencyPenalty,
|
||||
penalty_decay: selectedConfig.apiParameters.penaltyDecay,
|
||||
global_penalty: selectedConfig.apiParameters.globalPenalty
|
||||
global_penalty: selectedConfig.apiParameters.globalPenalty,
|
||||
state: selectedConfig.apiParameters.stateModel
|
||||
}).then(async r => {
|
||||
if (r.status !== 200) {
|
||||
const error = await r.text();
|
||||
if (error.includes('state shape mismatch'))
|
||||
toast(t('State model mismatch'), { type: 'error' });
|
||||
else if (error.includes('file format of the model or state model not supported'))
|
||||
toast(t('File format of the model or state model not supported'), { type: 'error' });
|
||||
else
|
||||
toast(error, { type: 'error' });
|
||||
}
|
||||
});
|
||||
toast(t('Config Saved'), { autoClose: 300, type: 'success' });
|
||||
};
|
||||
@@ -200,6 +216,34 @@ const Configs: FC = observer(() => {
|
||||
});
|
||||
}} />
|
||||
} />
|
||||
{isDynamicStateSupported(selectedConfig) &&
|
||||
<div className="sm:col-span-2 flex gap-2 items-center min-w-0">
|
||||
<Tooltip content={<div>
|
||||
{t('State-tuned Model')}, {t('See More')}: <Link
|
||||
onClick={() => BrowserOpenURL('https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead')}>{'https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead'}
|
||||
</Link>
|
||||
</div>} showDelay={0} hideDelay={0}
|
||||
relationship="description">
|
||||
<div className="shrink-0">
|
||||
{t('State Model') + ' *'}
|
||||
</div>
|
||||
</Tooltip>
|
||||
<Select style={{ minWidth: 0 }} className="grow"
|
||||
value={selectedConfig.apiParameters.stateModel}
|
||||
onChange={(e, data) => {
|
||||
setSelectedConfigApiParams({
|
||||
stateModel: data.value
|
||||
});
|
||||
}}>
|
||||
<option key={-1} value={''}>
|
||||
{t('None')}
|
||||
</option>
|
||||
{commonStore.stateModels.map((modelName, index) =>
|
||||
<option key={index} value={modelName}>{modelName}</option>
|
||||
)}
|
||||
</Select>
|
||||
</div>
|
||||
}
|
||||
<Accordion className="sm:col-span-2" collapsible
|
||||
openItems={!commonStore.apiParamsCollapsed && 'advanced'}
|
||||
onToggle={(e, data) => {
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
import commonStore, { MonitorData, Platform } from './stores/commonStore';
|
||||
import { FileExists, GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
|
||||
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshLocalModels, refreshModels } from './utils';
|
||||
import {
|
||||
bytesToMb,
|
||||
Cache,
|
||||
checkUpdate,
|
||||
downloadProgramFiles,
|
||||
LocalConfig,
|
||||
refreshLocalModels,
|
||||
refreshModels
|
||||
} from './utils';
|
||||
import { getStatus } from './apis';
|
||||
import { EventsOn, WindowSetTitle } from '../wailsjs/runtime';
|
||||
import manifest from '../../manifest.json';
|
||||
@@ -29,6 +37,7 @@ export async function startup() {
|
||||
});
|
||||
initLocalModelsNotify();
|
||||
initLoraModels();
|
||||
initStateModels();
|
||||
initHardwareMonitor();
|
||||
initMidi();
|
||||
}
|
||||
@@ -124,12 +133,42 @@ async function initLoraModels() {
|
||||
});
|
||||
}
|
||||
|
||||
async function initStateModels() {
|
||||
const refreshStateModels = throttle(async () => {
|
||||
const stateModels = await ListDirFiles('state-models').then((data) => {
|
||||
if (!data) return [];
|
||||
const stateModels = [];
|
||||
for (const f of data) {
|
||||
if (!f.isDir && f.name.endsWith('.pth')) {
|
||||
stateModels.push('state-models/' + f.name);
|
||||
}
|
||||
}
|
||||
return stateModels;
|
||||
});
|
||||
await ListDirFiles('models').then((data) => {
|
||||
if (!data) return;
|
||||
for (const f of data) {
|
||||
if (!f.isDir && f.name.endsWith('.pth') && Number(bytesToMb(f.size)) < 200) {
|
||||
stateModels.push('models/' + f.name);
|
||||
}
|
||||
}
|
||||
});
|
||||
commonStore.setStateModels(stateModels);
|
||||
}, 2000);
|
||||
|
||||
refreshStateModels();
|
||||
EventsOn('fsnotify', (data: string) => {
|
||||
if ((data.includes('models') && !data.includes('lora-models')) || data.includes('state-models'))
|
||||
refreshStateModels();
|
||||
});
|
||||
}
|
||||
|
||||
async function initLocalModelsNotify() {
|
||||
const throttleRefreshLocalModels = throttle(() => {
|
||||
refreshLocalModels({ models: commonStore.modelSourceList }, false); //TODO fix bug that only add models
|
||||
}, 2000);
|
||||
EventsOn('fsnotify', (data: string) => {
|
||||
if (data.includes('models') && !data.includes('lora-models'))
|
||||
if (data.includes('models') && !data.includes('lora-models') && !data.includes('state-models'))
|
||||
throttleRefreshLocalModels();
|
||||
});
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ class CommonStore {
|
||||
platform: Platform = 'windows';
|
||||
proxyPort: number = 0;
|
||||
lastModelName: string = '';
|
||||
stateModels: string[] = [];
|
||||
// presets manager
|
||||
editingPreset: Preset | null = null;
|
||||
presets: Preset[] = [];
|
||||
@@ -410,6 +411,10 @@ class CommonStore {
|
||||
this.loraModels = value;
|
||||
}
|
||||
|
||||
setStateModels(value: string[]) {
|
||||
this.stateModels = value;
|
||||
}
|
||||
|
||||
setAttachmentUploading(value: boolean) {
|
||||
this.attachmentUploading = value;
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ export type ApiParameters = {
|
||||
frequencyPenalty: number;
|
||||
penaltyDecay?: number;
|
||||
globalPenalty?: boolean;
|
||||
stateModel?: string;
|
||||
}
|
||||
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';
|
||||
|
||||
@@ -676,4 +676,11 @@ export function newChatConversation() {
|
||||
commonStore.setConversationOrder(conversationOrder);
|
||||
};
|
||||
return { pushMessage, saveConversation };
|
||||
}
|
||||
|
||||
export function isDynamicStateSupported(modelConfig: ModelConfig) {
|
||||
return modelConfig.modelParameters.device === 'CUDA' ||
|
||||
modelConfig.modelParameters.device === 'CPU' ||
|
||||
modelConfig.modelParameters.device === 'Custom' ||
|
||||
modelConfig.modelParameters.device === 'MPS';
|
||||
}
|
||||
Reference in New Issue
Block a user