add support for dynamic state-tuned models

This commit is contained in:
josc146 2024-05-12 21:51:24 +08:00
parent b52873cb37
commit a2bbbabee2
12 changed files with 230 additions and 15 deletions

View File

@ -125,6 +125,7 @@ func (a *App) OnStartup(ctx context.Context) {
os.Chmod(a.exDir+"backend-rust/web-rwkv-converter", 0777)
os.Mkdir(a.exDir+"models", os.ModePerm)
os.Mkdir(a.exDir+"lora-models", os.ModePerm)
os.Mkdir(a.exDir+"state-models", os.ModePerm)
os.Mkdir(a.exDir+"finetune/json2binidx_tool/data", os.ModePerm)
trainLogPath := "lora-models/train_log.txt"
if !a.FileExists(trainLogPath) {
@ -151,8 +152,9 @@ func (a *App) OnBeforeClose(ctx context.Context) bool {
func (a *App) watchFs() {
watcher, err := fsnotify.NewWatcher()
if err == nil {
watcher.Add(a.exDir + "./lora-models")
watcher.Add(a.exDir + "./models")
watcher.Add(a.exDir + "./lora-models")
watcher.Add(a.exDir + "./state-models")
go func() {
for {
select {

View File

@ -120,6 +120,9 @@ def update_config(body: ModelConfigBody):
model_config = ModelConfigBody()
global_var.set(global_var.Model_Config, model_config)
merge_model(model_config, body)
exception = load_rwkv_state(global_var.get(global_var.Model), model_config.state)
if exception is not None:
raise exception
print("Updated Model Config:", model_config)
return "success"

View File

@ -176,6 +176,19 @@ def reset_state():
return "success"
def force_reset_state():
global trie, dtrie
if trie is None:
return
import cyac
trie = cyac.Trie()
dtrie = {}
gc.collect()
class LongestPrefixStateBody(BaseModel):
prompt: str

View File

@ -7,7 +7,7 @@ import re
import time
from typing import Dict, Iterable, List, Tuple, Union, Type, Callable
from utils.log import quick_log
from fastapi import HTTPException
from fastapi import HTTPException, status
from pydantic import BaseModel, Field
from routes import state_cache
import global_var
@ -27,6 +27,7 @@ class AbstractRWKV(ABC):
self.EOS_ID = 0
self.name = "rwkv"
self.model_path = ""
self.version = 4
self.model = model
self.pipeline = pipeline
@ -43,6 +44,8 @@ class AbstractRWKV(ABC):
self.penalty_alpha_frequency = 1
self.penalty_decay = 0.996
self.global_penalty = False
self.state_path = ""
self.state_tuned = None
@abstractmethod
def adjust_occurrence(self, occurrence: Dict, token: int):
@ -236,7 +239,10 @@ class AbstractRWKV(ABC):
except HTTPException:
pass
if cache is None or cache["prompt"] == "" or cache["state"] is None:
self.model_state = None
if self.state_path:
self.model_state = copy.deepcopy(self.state_tuned)
else:
self.model_state = None
self.model_tokens = []
else:
delta_prompt = prompt[len(cache["prompt"]) :]
@ -606,13 +612,13 @@ def get_model_path(model_path: str) -> str:
def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV:
model = get_model_path(model)
model_path = get_model_path(model)
rwkv_beta = global_var.get(global_var.Args).rwkv_beta
rwkv_cpp = getattr(global_var.get(global_var.Args), "rwkv.cpp")
webgpu = global_var.get(global_var.Args).webgpu
if "midi" in model.lower() or "abc" in model.lower():
if "midi" in model_path.lower() or "abc" in model_path.lower():
os.environ["RWKV_RESCALE_LAYER"] = "999"
# dynamic import to make RWKV_CUDA_ON work
@ -637,8 +643,8 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
)
from rwkv_pip.utils import PIPELINE
filename, _ = os.path.splitext(os.path.basename(model))
model = Model(model, strategy)
filename, _ = os.path.splitext(os.path.basename(model_path))
model = Model(model_path, strategy)
if not tokenizer:
tokenizer = get_tokenizer(len(model.w["emb.weight"]))
pipeline = PIPELINE(model, tokenizer)
@ -671,6 +677,7 @@ def RWKV(model: str, strategy: str, tokenizer: Union[str, None]) -> AbstractRWKV
else:
rwkv = TextRWKV(model, pipeline)
rwkv.name = filename
rwkv.model_path = model_path
rwkv.version = model.version
return rwkv
@ -688,6 +695,7 @@ class ModelConfigBody(BaseModel):
default=None,
description="When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.",
)
state: str = Field(default=None, description="state-tuned file path")
model_config = {
"json_schema_extra": {
@ -699,11 +707,80 @@ class ModelConfigBody(BaseModel):
"frequency_penalty": 1,
"penalty_decay": 0.996,
"global_penalty": False,
"state": "",
}
}
}
def load_rwkv_state(model: AbstractRWKV, state_path: str) -> HTTPException:
if model:
if state_path:
if model.model_path.endswith(".pth") and state_path.endswith(".pth"):
import torch
state_path = get_model_path(state_path)
if model.state_path == state_path:
return
state_raw = torch.load(state_path, map_location="cpu")
state_raw_shape = next(iter(state_raw.values())).shape
args = model.model.args
if (
len(state_raw) != args.n_layer
or state_raw_shape[0] * state_raw_shape[1] != args.n_embd
):
if model.state_path:
pass
else:
print("state failed to load")
return HTTPException(
status.HTTP_400_BAD_REQUEST, "state shape mismatch"
)
strategy = model.model.strategy
model.state_tuned = [None] * args.n_layer * 3
for i in range(args.n_layer):
dd = strategy[i]
dev = dd.device
atype = dd.atype
model.state_tuned[i * 3 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
model.state_tuned[i * 3 + 1] = (
state_raw[f"blocks.{i}.att.time_state"]
.transpose(1, 2)
.to(dtype=torch.float, device=dev)
.requires_grad_(False)
.contiguous()
)
model.state_tuned[i * 3 + 2] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
state_cache.force_reset_state()
model.state_path = state_path
print("state loaded")
else:
if model.state_path:
pass
else:
print("state failed to load")
return HTTPException(
status.HTTP_400_BAD_REQUEST,
"file format of the model or state model not supported",
)
else:
state_cache.force_reset_state()
model.state_path = ""
model.state_tuned = None # TODO cached
print("state unloaded")
else:
print("state not loaded")
def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
if body.max_tokens is not None:
model.max_tokens_per_generation = body.max_tokens
@ -724,6 +801,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
model.top_k = body.top_k
if body.global_penalty is not None:
model.global_penalty = body.global_penalty
if body.state is not None:
load_rwkv_state(model, body.state)
def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
@ -736,4 +815,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
penalty_decay=model.penalty_decay,
top_k=model.top_k,
global_penalty=model.global_penalty,
state=model.state_path,
)

View File

@ -354,5 +354,10 @@
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "モデル内部には、一般的な問題の処理を改善するためのデフォルトのプロンプトがありますが、役割演技の効果を低下させる可能性があります。このオプションを無効にすることで、より良い役割演技効果を得ることができます。",
"Exit without saving": "保存せずに終了",
"Content has been changed, are you sure you want to exit without saving?": "コンテンツが変更されています、保存せずに終了してもよろしいですか?",
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。"
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "Ollama APIチャットモデル名を正しく記入するのを忘れないでください。",
"State-tuned Model": "State調整モデル",
"See More": "もっと見る",
"State Model": "Stateモデル",
"State model mismatch": "Stateモデルの不一致",
"File format of the model or state model not supported": "モデルまたはStateモデルのファイル形式がサポートされていません"
}

View File

@ -354,5 +354,10 @@
"Inside the model, there is a default prompt to improve the model's handling of common issues, but it may degrade the role-playing effect. You can disable this option to achieve a better role-playing effect.": "模型内部有一个默认提示来改善模型处理常规问题的效果, 但它可能会让角色扮演的效果变差, 你可以关闭此选项来获得更好的角色扮演效果",
"Exit without saving": "退出而不保存",
"Content has been changed, are you sure you want to exit without saving?": "内容已经被修改, 你确定要退出而不保存吗?",
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名"
"Don't forget to correctly fill in your Ollama API Chat Model Name.": "不要忘记正确填写你的Ollama API 聊天模型名",
"State-tuned Model": "State微调模型",
"See More": "查看更多",
"State Model": "State模型",
"State model mismatch": "State模型不匹配",
"File format of the model or state model not supported": "模型或state模型的文件格式不支持"
}

View File

@ -214,7 +214,18 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
presence_penalty: modelConfig.apiParameters.presencePenalty,
frequency_penalty: modelConfig.apiParameters.frequencyPenalty,
penalty_decay: modelConfig.apiParameters.penaltyDecay,
global_penalty: modelConfig.apiParameters.globalPenalty
global_penalty: modelConfig.apiParameters.globalPenalty,
state: modelConfig.apiParameters.stateModel
}).then(async r => {
if (r.status !== 200) {
const error = await r.text();
if (error.includes('state shape mismatch'))
toast(t('State model mismatch'), { type: 'error' });
else if (error.includes('file format of the model or state model not supported'))
toast(t('File format of the model or state model not supported'), { type: 'error' });
else
toast(error, { type: 'error' });
}
});
}

View File

@ -7,11 +7,13 @@ import {
Dropdown,
Input,
Label,
Link,
Option,
PresenceBadge,
Select,
Switch,
Text
Text,
Tooltip
} from '@fluentui/react-components';
import { AddCircle20Regular, DataUsageSettings20Regular, Delete20Regular, Save20Regular } from '@fluentui/react-icons';
import React, { FC, useCallback, useEffect, useRef } from 'react';
@ -27,7 +29,7 @@ import { Page } from '../components/Page';
import { useNavigate } from 'react-router';
import { RunButton } from '../components/RunButton';
import { updateConfig } from '../apis';
import { getStrategy } from '../utils';
import { getStrategy, isDynamicStateSupported } from '../utils';
import { useTranslation } from 'react-i18next';
import strategyImg from '../assets/images/strategy.jpg';
import strategyZhImg from '../assets/images/strategy_zh.jpg';
@ -36,6 +38,7 @@ import { useMediaQuery } from 'usehooks-ts';
import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
import { defaultPenaltyDecay } from './defaultConfigs';
import { BrowserOpenURL } from '../../wailsjs/runtime';
const ConfigSelector: FC<{
selectedIndex: number,
@ -112,6 +115,8 @@ const Configs: FC = observer(() => {
const onClickSave = () => {
commonStore.setModelConfig(selectedIndex, selectedConfig);
// When clicking RunButton in Configs page, updateConfig will be called twice,
// because there are also RunButton in other pages, and the calls to updateConfig in both places are necessary.
updateConfig({
max_tokens: selectedConfig.apiParameters.maxResponseToken,
temperature: selectedConfig.apiParameters.temperature,
@ -119,7 +124,18 @@ const Configs: FC = observer(() => {
presence_penalty: selectedConfig.apiParameters.presencePenalty,
frequency_penalty: selectedConfig.apiParameters.frequencyPenalty,
penalty_decay: selectedConfig.apiParameters.penaltyDecay,
global_penalty: selectedConfig.apiParameters.globalPenalty
global_penalty: selectedConfig.apiParameters.globalPenalty,
state: selectedConfig.apiParameters.stateModel
}).then(async r => {
if (r.status !== 200) {
const error = await r.text();
if (error.includes('state shape mismatch'))
toast(t('State model mismatch'), { type: 'error' });
else if (error.includes('file format of the model or state model not supported'))
toast(t('File format of the model or state model not supported'), { type: 'error' });
else
toast(error, { type: 'error' });
}
});
toast(t('Config Saved'), { autoClose: 300, type: 'success' });
};
@ -200,6 +216,34 @@ const Configs: FC = observer(() => {
});
}} />
} />
{isDynamicStateSupported(selectedConfig) &&
<div className="sm:col-span-2 flex gap-2 items-center min-w-0">
<Tooltip content={<div>
{t('State-tuned Model')}, {t('See More')}: <Link
onClick={() => BrowserOpenURL('https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead')}>{'https://github.com/BlinkDL/RWKV-LM#state-tuning-tuning-the-initial-state-zero-inference-overhead'}
</Link>
</div>} showDelay={0} hideDelay={0}
relationship="description">
<div className="shrink-0">
{t('State Model') + ' *'}
</div>
</Tooltip>
<Select style={{ minWidth: 0 }} className="grow"
value={selectedConfig.apiParameters.stateModel}
onChange={(e, data) => {
setSelectedConfigApiParams({
stateModel: data.value
});
}}>
<option key={-1} value={''}>
{t('None')}
</option>
{commonStore.stateModels.map((modelName, index) =>
<option key={index} value={modelName}>{modelName}</option>
)}
</Select>
</div>
}
<Accordion className="sm:col-span-2" collapsible
openItems={!commonStore.apiParamsCollapsed && 'advanced'}
onToggle={(e, data) => {

View File

@ -1,6 +1,14 @@
import commonStore, { MonitorData, Platform } from './stores/commonStore';
import { FileExists, GetPlatform, ListDirFiles, ReadJson } from '../wailsjs/go/backend_golang/App';
import { Cache, checkUpdate, downloadProgramFiles, LocalConfig, refreshLocalModels, refreshModels } from './utils';
import {
bytesToMb,
Cache,
checkUpdate,
downloadProgramFiles,
LocalConfig,
refreshLocalModels,
refreshModels
} from './utils';
import { getStatus } from './apis';
import { EventsOn, WindowSetTitle } from '../wailsjs/runtime';
import manifest from '../../manifest.json';
@ -29,6 +37,7 @@ export async function startup() {
});
initLocalModelsNotify();
initLoraModels();
initStateModels();
initHardwareMonitor();
initMidi();
}
@ -124,12 +133,42 @@ async function initLoraModels() {
});
}
async function initStateModels() {
const refreshStateModels = throttle(async () => {
const stateModels = await ListDirFiles('state-models').then((data) => {
if (!data) return [];
const stateModels = [];
for (const f of data) {
if (!f.isDir && f.name.endsWith('.pth')) {
stateModels.push('state-models/' + f.name);
}
}
return stateModels;
});
await ListDirFiles('models').then((data) => {
if (!data) return;
for (const f of data) {
if (!f.isDir && f.name.endsWith('.pth') && Number(bytesToMb(f.size)) < 200) {
stateModels.push('models/' + f.name);
}
}
});
commonStore.setStateModels(stateModels);
}, 2000);
refreshStateModels();
EventsOn('fsnotify', (data: string) => {
if ((data.includes('models') && !data.includes('lora-models')) || data.includes('state-models'))
refreshStateModels();
});
}
async function initLocalModelsNotify() {
const throttleRefreshLocalModels = throttle(() => {
refreshLocalModels({ models: commonStore.modelSourceList }, false); //TODO fix bug that only add models
}, 2000);
EventsOn('fsnotify', (data: string) => {
if (data.includes('models') && !data.includes('lora-models'))
if (data.includes('models') && !data.includes('lora-models') && !data.includes('state-models'))
throttleRefreshLocalModels();
});
}

View File

@ -65,6 +65,7 @@ class CommonStore {
platform: Platform = 'windows';
proxyPort: number = 0;
lastModelName: string = '';
stateModels: string[] = [];
// presets manager
editingPreset: Preset | null = null;
presets: Preset[] = [];
@ -410,6 +411,10 @@ class CommonStore {
this.loraModels = value;
}
setStateModels(value: string[]) {
this.stateModels = value;
}
setAttachmentUploading(value: boolean) {
this.attachmentUploading = value;
}

View File

@ -7,6 +7,7 @@ export type ApiParameters = {
frequencyPenalty: number;
penaltyDecay?: number;
globalPenalty?: boolean;
stateModel?: string;
}
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom';
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';

View File

@ -677,3 +677,10 @@ export function newChatConversation() {
};
return { pushMessage, saveConversation };
}
export function isDynamicStateSupported(modelConfig: ModelConfig) {
return modelConfig.modelParameters.device === 'CUDA' ||
modelConfig.modelParameters.device === 'CPU' ||
modelConfig.modelParameters.device === 'Custom' ||
modelConfig.modelParameters.device === 'MPS';
}