add support for dynamic state-tuned models
This commit is contained in:
parent
b52873cb37
commit
a2bbbabee2
@ -125,6 +125,7 @@ func (a *App) OnStartup(ctx context.Context) {
|
||||
os.Chmod(a.exDir+"backend-rust/web-rwkv-converter", 0777)
|
||||
os.Mkdir(a.exDir+"models", os.ModePerm)
|
||||
os.Mkdir(a.exDir+"lora-models", os.ModePerm)
|
||||
os.Mkdir(a.exDir+"state-models", os.ModePerm)
|
||||
os.Mkdir(a.exDir+"finetune/json2binidx_tool/data", os.ModePerm)
|
||||
trainLogPath := "lora-models/train_log.txt"
|
||||
if !a.FileExists(trainLogPath) {
|
||||
@ -151,8 +152,9 @@ func (a *App) OnBeforeClose(ctx context.Context) bool {
|
||||
func (a *App) watchFs() {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err == nil {
|
||||
watcher.Add(a.exDir + "./lora-models")
|
||||
watcher.Add(a.exDir + "./models")
|
||||
watcher.Add(a.exDir + "./lora-models")
|
||||
watcher.Add(a.exDir + "./state-models")
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
|
@ -120,6 +120,9 @@ def update_config(body: ModelConfigBody):
|
||||
model_config = ModelConfigBody()
|
||||
global_var.set(global_var.Model_Config, model_config)
|
||||
merge_model(model_config, body)
|
||||
exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state)
|
||||
if exception is not None:
|
||||
raise exception
|
||||
print("Updated Model Config:", model_config)
|
||||
|
||||
return "success"
|
||||
|
@ -176,6 +176,19 @@ def reset_state():
|
||||
return "success"
|
||||
|
||||
|
||||
def force_reset_state():
|
||||
global trie, dtrie
|
||||
|
||||
if trie is None:
|
||||
return
|
||||
|
||||
import cyac
|
||||
|
||||
trie = cyac.Trie()
|
||||
dtrie = {}
|
||||
gc.collect()
|
||||
|
||||
|
||||
class LongestPrefixStateBody(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
@ -7,7 +7,7 @@ import re
|
||||
import time
|
||||
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
|
||||
from utils.log import quick_log
|
||||
from fastapi import HTTPException
|
||||
from fastapi import HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from routes import state_cache
|
||||
import global_var
|
||||
@ -27,6 +27,7 @@ class AbstractRWKV(ABC):
|
||||
self.EOS_ID = 0
|
||||
|
||||
self.name = "rwkv"
|
||||
self.model_path = ""
|
||||
self.version = 4
|
||||
self.model = model
|
||||
self.pipeline = pipeline
|
||||
@ -43,6 +44,8 @@ class AbstractRWKV(ABC):
|
||||
self.penalty_alpha_frequency = 1
|
||||
self.penalty_decay = 0.996
|
||||
self.global_penalty = False
|
||||
self.state_path = ""
|
||||
self.state_tuned = None
|
||||
|
||||
@abstractmethod
|
||||
def adjust_occurrence(self, occurrence: Dict, token: int):
|
||||
@ -236,7 +239,10 @@ class AbstractRWKV(ABC):
|
||||
except HTTPException:
|
||||
pass
|
||||
if cache is None or cache["prompt"] == "" or cache["state"] is None:
|
||||
self.model_state = None
|
||||
if self.state_path:
|
||||
self.model_state = copy.deepcopy(self.state_tuned)
|
||||
else:
|
||||
self.model_state = None
|
||||
self.model_tokens = []
|
||||
else:
|
||||
delta_prompt = prompt[len(cache["prompt"]) :]
|
||||
@ -606,13 +612,13 @@ def get_model_path(model_path: str) -> str:
|
||||
|
||||
|
||||
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
|
||||
model = get_model_path(model)
|
||||
model_path = get_model_path(model)
|
||||
|
||||
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
|
||||
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
|
||||
webgpu = global_var.get(global_var.Args).webgpu
|
||||
|
||||
if "midi" in model.lower() or "abc" in model.lower():
|
||||
if "midi" in model_path.lower() or "abc" in model_path.lower():
|
||||
os.environ["RWKV_RESCALE_LAYER"] = "999"
|
||||
|
||||
# dynamic import to make RWKV_CUDA_ON work
|
||||
@ -637,8 +643,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
|
||||
)
|
||||
from rwkv_pip.utils import PIPELINE
|
||||
|
||||
filename, _ = os.path.splitext(os.path.basename(model))
|
||||
model = Model(model, strategy)
|
||||
filename, _ = os.path.splitext(os.path.basename(model_path))
|
||||
model = Model(model_path, strategy)
|
||||
if not tokenizer:
|
||||
tokenizer = get_tokenizer(len(model.w["emb.weight"]))
|
||||
pipeline = PIPELINE(model, tokenizer)
|
||||
@ -671,6 +677,7 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
|
||||
else:
|
||||
rwkv = TextRWKV(model, pipeline)
|
||||
rwkv.name = filename
|
||||
rwkv.model_path = model_path
|
||||
rwkv.version = model.version
|
||||
|
||||
return rwkv
|
||||
@ -688,6 +695,7 @@ class ModelConfigBody(BaseModel):
|
||||
default=None,
|
||||
description="When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.",
|
||||
)
|
||||
state: str = Field(default=None, description="state-tuned file path")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
@ -699,11 +707,80 @@ class ModelConfigBody(BaseModel):
|
||||
"frequency_penalty": 1,
|
||||
"penalty_decay": 0.996,
|
||||
"global_penalty": False,
|
||||
"state": "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
|
||||
if model:
|
||||
if state_path:
|
||||
if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
|
||||
import torch
|
||||
|
||||
state_path = get_model_path(state_path)
|
||||
if model.state_path == state_path:
|
||||
return
|
||||
|
||||
state_raw = torch.load(state_path, map_location="cpu")
|
||||
state_raw_shape = next(iter(state_raw.values())).shape
|
||||
|
||||
args = model.model.args
|
||||
if (
|
||||
len(state_raw) != args.n_layer
|
||||
or state_raw_shape[0] * state_raw_shape[1] != args.n_embd
|
||||
):
|
||||
if model.state_path:
|
||||
pass
|
||||
else:
|
||||
print("state failed to load")
|
||||
return HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, "state shape mismatch"
|
||||
)
|
||||
|
||||
strategy = model.model.strategy
|
||||
model.state_tuned = [None] * args.n_layer * 3
|
||||
|
||||
for i in range(args.n_layer):
|
||||
dd = strategy[i]
|
||||
dev = dd.device
|
||||
atype = dd.atype
|
||||
model.state_tuned[i * 3 + 0] = torch.zeros(
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
model.state_tuned[i * 3 + 1] = (
|
||||
state_raw[f"blocks.{i}.att.time_state"]
|
||||
.transpose(1, 2)
|
||||
.to(dtype=torch.float, device=dev)
|
||||
.requires_grad_(False)
|
||||
.contiguous()
|
||||
)
|
||||
model.state_tuned[i * 3 + 2] = torch.zeros(
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
|
||||
state_cache.force_reset_state()
|
||||
model.state_path = state_path
|
||||
print("state loaded")
|
||||
else:
|
||||
if model.state_path:
|
||||
pass
|
||||
else:
|
||||
print("state failed to load")
|
||||
return HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"file format of the model or state model not supported",
|
||||
)
|
||||
else:
|
||||
state_cache.force_reset_state()
|
||||
model.state_path = ""
|
||||
model.state_tuned = None # TODO cached
|
||||
print("state unloaded")
|
||||
else:
|
||||
print("state not loaded")
|
||||
|
||||
|
||||
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
||||
if body.max_tokens is not None:
|
||||
model.max_tokens_per_generation = body.max_tokens
|
||||
@ -724,6 +801,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
|
||||
model.top_k = body.top_k
|
||||
if body.global_penalty is not None:
|
||||
model.global_penalty = body.global_penalty
|
||||
if body.state is not None:
|
||||
load_rwkv_state(model, body.state)
|
||||
|
||||
|
||||
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||
@ -736,4 +815,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
|
||||
penalty_decay=model.penalty_decay,
|
||||
top_k=model.top_k,
|
||||
global_penalty=model.global_penalty,
|
||||
state=model.state_path,
|
||||
)
|
||||
|
@ -354,5 +354,10 @@
|
||||
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "モデル内部には、一般的な問題の処理を改善するためのデフォルトのプロンプトがありますが、役割演技の効果を低下させる可能性があります。このオプションを無効にすることで、より良い役割演技効果を得ることができます。",
|
||||
"Exit without saving": "保存せずに終了",
|
||||
"Content has been changed, are you sure you want to exit without saving?": "コンテンツが変更されています、保存せずに終了してもよろしいですか?",
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。"
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。",
|
||||
"State-tuned Model": "State調整モデル",
|
||||
"See More": "もっと見る",
|
||||
"State Model": "Stateモデル",
|
||||
"State model mismatch": "Stateモデルの不一致",
|
||||
"File format of the model or state model not supported": "モデルまたはStateモデルのファイル形式がサポートされていません"
|
||||
}
|
@ -354,5 +354,10 @@
|
||||
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "模型内部有一个默认提示来改善模型处理常规问题的效果, 但它可能会让角色扮演的效果变差, 你可以关闭此选项来获得更好的角色扮演效果",
|
||||
"Exit without saving": "退出而不保存",
|
||||
"Content has been changed, are you sure you want to exit without saving?": "内容已经被修改, 你确定要退出而不保存吗?",
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名"
|
||||
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名",
|
||||
"State-tuned Model": "State微调模型",
|
||||
"See More": "查看更多",
|
||||
"State Model": "State模型",
|
||||
"State model mismatch": "State模型不匹配",
|
||||
"File format of the model or state model not supported": "模型或state模型的文件格式不支持"
|
||||
}
|
@ -214,7 +214,18 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
presence_penalty: modelConfig.apiParameters.presencePenalty,
|
||||
frequency_penalty: modelConfig.apiParameters.frequencyPenalty,
|
||||
penalty_decay: modelConfig.apiParameters.penaltyDecay,
|
||||
global_penalty: modelConfig.apiParameters.globalPenalty
|
||||
global_penalty: modelConfig.apiParameters.globalPenalty,
|
||||
state: modelConfig.apiParameters.stateModel
|
||||
}).then(async r => {
|
||||
if (r.status !== 200) {
|
||||
const error = await r.text();
|
||||
if (error.includes('state shape mismatch'))
|
||||
toast(t('State model mismatch'), { type: 'error' });
|
||||
else if (error.includes('file format of the model or state model not supported'))
|
||||
toast(t('File format of the model or state model not supported'), { type: 'error' });
|
||||
else
|
||||
toast(error, { type: 'error' });
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -7,11 +7,13 @@ import {
|
||||
Dropdown,
|
||||
Input,
|
||||
Label,
|
||||
Link,
|
||||
Option,
|
||||
PresenceBadge,
|
||||
Select,
|
||||
Switch,
|
||||
Text
|
||||
Text,
|
||||
Tooltip
|
||||
} from '@fluentui/react-components';
|
||||
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
|
||||
import React, { FC, useCallback, useEffect, useRef } from 'react';
|
||||
@ -27,7 +29,7 @@ import { Page } from '../components/Page';
|
||||
import { useNavigate } from 'react-router';
|
||||
import { RunButton } from '../components/RunButton';
|
||||
import { updateConfig } from '../apis';
|
||||
import { getStrategy } from '../utils';
|
||||
import { getStrategy, isDynamicStateSupported } from '../utils';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import strategyImg from '../assets/images/strategy.jpg';
|
||||
import strategyZhImg from '../assets/images/strategy_zh.jpg';
|
||||
@ -36,6 +38,7 @@ import { useMediaQuery } from 'usehooks-ts';
|
||||
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
|
||||
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
|
||||
import { defaultPenaltyDecay } from './defaultConfigs';
|
||||
import { BrowserOpenURL } from '../../wailsjs/runtime';
|
||||
|
||||
const ConfigSelector: FC<{
|
||||
selectedIndex: number,
|
||||
@ -112,6 +115,8 @@ const Configs: FC = observer(() => {
|
||||
|
||||
const onClickSave = () => {
|
||||
commonStore.setModelConfig(selectedIndex, selectedConfig);
|
||||
// When clicking RunButton in Configs page, updateConfig will be called twice,
|
||||
// because there are also RunButton in other pages, and the calls to updateConfig in both places are necessary.
|
||||
updateConfig({
|
||||
max_tokens: selectedConfig.apiParameters.maxResponseToken,
|
||||
temperature: selectedConfig.apiParameters.temperature,
|
||||
@ -119,7 +124,18 @@ const Configs: FC = observer(() => {
|
||||
presence_penalty: selectedConfig.apiParameters.presencePenalty,
|
||||
frequency_penalty: selectedConfig.apiParameters.frequencyPenalty,
|
||||
penalty_decay: selectedConfig.apiParameters.penaltyDecay,
|
||||
global_penalty: selectedConfig.apiParameters.globalPenalty
|
||||
global_penalty: selectedConfig.apiParameters.globalPenalty,
|
||||
state: selectedConfig.apiParameters.stateModel
|
||||
}).then(async r => {
|
||||
if (r.status !== 200) {
|
||||
const error = await r.text();
|
||||
if (error.includes('state shape mismatch'))
|
||||
toast(t('State model mismatch'), { type: 'error' });
|
||||
else if (error.includes('file format of the model or state model not supported'))
|
||||
toast(t('File format of the model or state model not supported'), { type: 'error' });
|
||||
else
|
||||
toast(error, { type: 'error' });
|
||||
}
|
||||
});
|
||||
toast(t('Config Saved'), { autoClose: 300, type: 'success' });
|
||||
};
|
||||
@ -200,6 +216,34 @@ const Configs: FC = observer(() => {
|
||||
});
|
||||
}} />
|
||||
} />
|
||||
{isDynamicStateSupported(selectedConfig) &&
|
||||
<div className="sm:col-span-2 flex gap-2 items-center min-w-0">
|
||||
<Tooltip content={<div>
|
||||
{t('State-tuned Model')}, {t('See More')}: <Link
|
||||
onClick={() => BrowserOpenURL('https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead')}>{'https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead'}
|
||||
</Link>
|
||||
</div>} showDelay={0} hideDelay={0}
|
||||
relationship="description">
|
||||
<div className="shrink-0">
|
||||
{t('State Model') + ' *'}
|
||||
</div>
|
||||
</Tooltip>
|
||||
<Select style={{ minWidth: 0 }} className="grow"
|
||||
value={selectedConfig.apiParameters.stateModel}
|
||||
onChange={(e, data) => {
|
||||
setSelectedConfigApiParams({
|
||||
stateModel: data.value
|
||||
});
|
||||
}}>
|
||||
<option key={-1} value={''}>
|
||||
{t('None')}
|
||||
</option>
|
||||
{commonStore.stateModels.map((modelName, index) =>
|
||||
<option key={index} value={modelName}>{modelName}</option>
|
||||
)}
|
||||
</Select>
|
||||
</div>
|
||||
}
|
||||
<Accordion className="sm:col-span-2" collapsible
|
||||
openItems={!commonStore.apiParamsCollapsed && 'advanced'}
|
||||
onToggle={(e, data) => {
|
||||
|
@ -1,6 +1,14 @@
|
||||
import commonStore, { MonitorData, Platform } from './stores/commonStore';
|
||||
import { FileExists, GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
|
||||
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshLocalModels, refreshModels } from './utils';
|
||||
import {
|
||||
bytesToMb,
|
||||
Cache,
|
||||
checkUpdate,
|
||||
downloadProgramFiles,
|
||||
LocalConfig,
|
||||
refreshLocalModels,
|
||||
refreshModels
|
||||
} from './utils';
|
||||
import { getStatus } from './apis';
|
||||
import { EventsOn, WindowSetTitle } from '../wailsjs/runtime';
|
||||
import manifest from '../../manifest.json';
|
||||
@ -29,6 +37,7 @@ export async function startup() {
|
||||
});
|
||||
initLocalModelsNotify();
|
||||
initLoraModels();
|
||||
initStateModels();
|
||||
initHardwareMonitor();
|
||||
initMidi();
|
||||
}
|
||||
@ -124,12 +133,42 @@ async function initLoraModels() {
|
||||
});
|
||||
}
|
||||
|
||||
async function initStateModels() {
|
||||
const refreshStateModels = throttle(async () => {
|
||||
const stateModels = await ListDirFiles('state-models').then((data) => {
|
||||
if (!data) return [];
|
||||
const stateModels = [];
|
||||
for (const f of data) {
|
||||
if (!f.isDir && f.name.endsWith('.pth')) {
|
||||
stateModels.push('state-models/' + f.name);
|
||||
}
|
||||
}
|
||||
return stateModels;
|
||||
});
|
||||
await ListDirFiles('models').then((data) => {
|
||||
if (!data) return;
|
||||
for (const f of data) {
|
||||
if (!f.isDir && f.name.endsWith('.pth') && Number(bytesToMb(f.size)) < 200) {
|
||||
stateModels.push('models/' + f.name);
|
||||
}
|
||||
}
|
||||
});
|
||||
commonStore.setStateModels(stateModels);
|
||||
}, 2000);
|
||||
|
||||
refreshStateModels();
|
||||
EventsOn('fsnotify', (data: string) => {
|
||||
if ((data.includes('models') && !data.includes('lora-models')) || data.includes('state-models'))
|
||||
refreshStateModels();
|
||||
});
|
||||
}
|
||||
|
||||
async function initLocalModelsNotify() {
|
||||
const throttleRefreshLocalModels = throttle(() => {
|
||||
refreshLocalModels({ models: commonStore.modelSourceList }, false); //TODO fix bug that only add models
|
||||
}, 2000);
|
||||
EventsOn('fsnotify', (data: string) => {
|
||||
if (data.includes('models') && !data.includes('lora-models'))
|
||||
if (data.includes('models') && !data.includes('lora-models') && !data.includes('state-models'))
|
||||
throttleRefreshLocalModels();
|
||||
});
|
||||
}
|
||||
|
@ -65,6 +65,7 @@ class CommonStore {
|
||||
platform: Platform = 'windows';
|
||||
proxyPort: number = 0;
|
||||
lastModelName: string = '';
|
||||
stateModels: string[] = [];
|
||||
// presets manager
|
||||
editingPreset: Preset | null = null;
|
||||
presets: Preset[] = [];
|
||||
@ -410,6 +411,10 @@ class CommonStore {
|
||||
this.loraModels = value;
|
||||
}
|
||||
|
||||
setStateModels(value: string[]) {
|
||||
this.stateModels = value;
|
||||
}
|
||||
|
||||
setAttachmentUploading(value: boolean) {
|
||||
this.attachmentUploading = value;
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ export type ApiParameters = {
|
||||
frequencyPenalty: number;
|
||||
penaltyDecay?: number;
|
||||
globalPenalty?: boolean;
|
||||
stateModel?: string;
|
||||
}
|
||||
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';
|
||||
|
@ -676,4 +676,11 @@ export function newChatConversation() {
|
||||
commonStore.setConversationOrder(conversationOrder);
|
||||
};
|
||||
return { pushMessage, saveConversation };
|
||||
}
|
||||
|
||||
export function isDynamicStateSupported(modelConfig: ModelConfig) {
|
||||
return modelConfig.modelParameters.device === 'CUDA' ||
|
||||
modelConfig.modelParameters.device === 'CPU' ||
|
||||
modelConfig.modelParameters.device === 'Custom' ||
|
||||
modelConfig.modelParameters.device === 'MPS';
|
||||
}
|
Loading…
Reference in New Issue
Block a user