webgpu support

parent 74f1a1c033
commit ef53951a16

This commit adds a WebGPU backend path: the Go backend gains StartWebGPUServer (launching the bundled backend-rust/webgpu_server) and ConvertSafetensors (a .pth to safetensors conversion script), and the frontend gains a WebGPU strategy option with the matching conversion button, model-list filtering, and server start-up wiring.

.gitattributes (vendored, +1)

@@ -2,6 +2,7 @@ backend-python/rwkv_pip/** linguist-vendored
 backend-python/wkv_cuda_utils/** linguist-vendored
 backend-python/get-pip.py linguist-vendored
 backend-python/convert_model.py linguist-vendored
+backend-python/convert_safetensors.py linguist-vendored
 backend-python/utils/midi.py linguist-vendored
 build/** linguist-vendored
 finetune/lora/** linguist-vendored

.gitignore (vendored, +2)

@@ -5,6 +5,8 @@ __pycache__
 .idea
 .vs
 *.pth
+*.st
+*.safetensors
 *.bin
 /config.json
 /cache.json

Go backend (package backend_golang):

@@ -26,6 +26,13 @@ func (a *App) StartServer(python string, port int, host string, rwkvBeta bool) (string, error) {
 	return Cmd(args...)
 }
 
+func (a *App) StartWebGPUServer(port int, host string) (string, error) {
+	args := []string{"./backend-rust/webgpu_server"}
+	args = append(args, "-a", "0", "-t", "backend-rust/assets/rwkv_vocab_v20230424.json",
+		"--port", strconv.Itoa(port), "--ip", host)
+	return Cmd(args...)
+}
+
 func (a *App) ConvertModel(python string, modelPath string, strategy string, outPath string) (string, error) {
 	var err error
 	if python == "" {
@@ -37,6 +44,17 @@ func (a *App) ConvertModel(python string, modelPath string, strategy string, outPath string) (string, error) {
 	return Cmd(python, "./backend-python/convert_model.py", "--in", modelPath, "--out", outPath, "--strategy", strategy)
 }
 
+func (a *App) ConvertSafetensors(python string, modelPath string, outPath string) (string, error) {
+	var err error
+	if python == "" {
+		python, err = GetPython()
+	}
+	if err != nil {
+		return "", err
+	}
+	return Cmd(python, "./backend-python/convert_safetensors.py", "--input", modelPath, "--output", outPath)
+}
+
 func (a *App) ConvertData(python string, input string, outputPrefix string, vocab string) (string, error) {
 	var err error
 	if python == "" {
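
As an aside, a minimal sketch of how the new binding is reached from the frontend through the Wails-generated wrapper declared later in this diff (the port and host are illustrative values, not taken from the commit):

    import { StartWebGPUServer } from '../../wailsjs/go/backend_golang/App';

    // StartWebGPUServer(port, host) shells out to ./backend-rust/webgpu_server with the
    // bundled rwkv_vocab_v20230424.json tokenizer; unlike StartServer, no Python
    // interpreter is involved.
    StartWebGPUServer(8000, '127.0.0.1').catch((e: any) => {
      console.error(e.message || e);
    });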

backend-python/convert_safetensors.py (vendored, new file, 53 lines)

@@ -0,0 +1,53 @@
+import json
+import os
+import sys
+import copy
+import torch
+from safetensors.torch import load_file, save_file
+
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--input", type=str, help="Path to input pth model")
+parser.add_argument(
+    "--output",
+    type=str,
+    default="./converted.st",
+    help="Path to output safetensors model",
+)
+args = parser.parse_args()
+
+
+def convert_file(
+    pt_filename: str,
+    sf_filename: str,
+):
+    loaded = torch.load(pt_filename, map_location="cpu")
+    if "state_dict" in loaded:
+        loaded = loaded["state_dict"]
+
+    loaded = {k: v.clone().half() for k, v in loaded.items()}
+    for k, v in loaded.items():
+        print(f"{k}\t{v.shape}\t{v.dtype}")
+
+    # For tensors to be contiguous
+    loaded = {k: v.contiguous() for k, v in loaded.items()}
+
+    dirname = os.path.dirname(sf_filename)
+    os.makedirs(dirname, exist_ok=True)
+    save_file(loaded, sf_filename, metadata={"format": "pt"})
+    reloaded = load_file(sf_filename)
+    for k in loaded:
+        pt_tensor = loaded[k]
+        sf_tensor = reloaded[k]
+        if not torch.equal(pt_tensor, sf_tensor):
+            raise RuntimeError(f"The output tensors do not match for key {k}")
+
+
+if __name__ == "__main__":
+    try:
+        convert_file(args.input, args.output)
+        print(f"Saved to {args.output}")
+    except Exception as e:
+        with open("error.txt", "w") as f:
+            f.write(str(e))
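
The ConvertSafetensors binding added above drives this script from the frontend; a hedged sketch of that flow (the helper name is mine, and the empty python argument, which falls back to GetPython() in the Go code, is an illustrative choice):

    import { ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';

    // Convert a .pth checkpoint to a safetensors file next to it, then confirm the
    // output exists; on failure the script writes error.txt and GetPyError surfaces
    // the message, as the Configs page later in this diff does.
    async function convertToSafetensors(modelPath: string): Promise<string> {
      const newModelPath = modelPath.replace(/\.pth$/, '.st');
      await ConvertSafetensors('', modelPath, newModelPath);
      if (!await FileExists(newModelPath))
        throw new Error(await GetPyError());
      return newModelPath;
    }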

backend-python (file name not shown in this view):

@@ -1,3 +1,4 @@
+import safetensors
 import midi2audio
 import mido
 import lm_dataformat

Three binary files changed (not shown).

backend-rust/assets/rwkv_vocab_v20230424.json (new file, 66861 lines): diff suppressed because it is too large.

Frontend locale, Japanese:

@@ -241,5 +241,8 @@
   "Auto Play At The End": "最後に自動再生",
   "No File to save": "保存するファイルがありません",
   "File Saved": "ファイルが保存されました",
-  "Failed to load local sound font, please check if the files exist - assets/sound-font": "ローカルサウンドフォントの読み込みに失敗しました、ファイルが存在するか確認してください - assets/sound-font"
+  "Failed to load local sound font, please check if the files exist - assets/sound-font": "ローカルサウンドフォントの読み込みに失敗しました、ファイルが存在するか確認してください - assets/sound-font",
+  "Please convert model to safe tensors format first": "モデルを安全なテンソル形式に変換してください",
+  "Convert To Safe Tensors Format": "安全なテンソル形式に変換",
+  "Please change Strategy to WebGPU to use safetensors format": "StrategyをWebGPUに変更して、安全なテンソル形式を使用してください"
 }

Frontend locale, Chinese:

@@ -241,5 +241,8 @@
   "Auto Play At The End": "结束时自动播放",
   "No File to save": "无文件可保存",
   "File Saved": "文件已保存",
-  "Failed to load local sound font, please check if the files exist - assets/sound-font": "加载本地音色资源失败,请检查文件是否存在 - assets/sound-font"
+  "Failed to load local sound font, please check if the files exist - assets/sound-font": "加载本地音色资源失败,请检查文件是否存在 - assets/sound-font",
+  "Please convert model to safe tensors format first": "请先将模型转换为Safetensors格式",
+  "Convert To Safe Tensors Format": "转换为Safetensors格式",
+  "Please change Strategy to WebGPU to use safetensors format": "请将Strategy改为WebGPU以使用safetensors格式"
 }

Frontend RunButton component:

@@ -1,6 +1,12 @@
 import React, { FC, MouseEventHandler, ReactElement } from 'react';
 import commonStore, { ModelStatus } from '../stores/commonStore';
-import { AddToDownloadList, CopyFile, FileExists, StartServer } from '../../wailsjs/go/backend_golang/App';
+import {
+  AddToDownloadList,
+  CopyFile,
+  FileExists,
+  StartServer,
+  StartWebGPUServer
+} from '../../wailsjs/go/backend_golang/App';
 import { Button } from '@fluentui/react-components';
 import { observer } from 'mobx-react-lite';
 import { exit, getStatus, readRoot, switchModel, updateConfig } from '../apis';

@@ -39,6 +45,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
     commonStore.setStatus({ status: ModelStatus.Starting });
 
     const modelConfig = commonStore.getCurrentModelConfig();
+    const webgpu = modelConfig.modelParameters.device === 'WebGPU';
     let modelName = '';
     let modelPath = '';
     if (modelConfig && modelConfig.modelParameters) {

@@ -50,9 +57,32 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
       return;
     }
 
-    const ok = await checkDependencies(navigate);
-    if (!ok)
-      return;
+    if (webgpu) {
+      if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
+        const stModelPath = modelPath.replace(/\.pth$/, '.st');
+        if (await FileExists(stModelPath)) {
+          modelPath = stModelPath;
+        } else {
+          toast(t('Please convert model to safe tensors format first'), { type: 'error' });
+          commonStore.setStatus({ status: ModelStatus.Offline });
+          return;
+        }
+      }
+    }
+
+    if (!webgpu) {
+      if (['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
+        toast(t('Please change Strategy to WebGPU to use safetensors format'), { type: 'error' });
+        commonStore.setStatus({ status: ModelStatus.Offline });
+        return;
+      }
+    }
+
+    if (!webgpu) {
+      const ok = await checkDependencies(navigate);
+      if (!ok)
+        return;
+    }
 
     const currentModelSource = commonStore.modelSourceList.find(item => item.name === modelName);
 
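
Condensed into a small pure function (the helper name and shape are mine, not part of the commit), the path rule above reads:

    // Returns the model path to load, or null when the user must first convert the
    // model (WebGPU strategy) or switch the Strategy to WebGPU (safetensors file selected).
    function resolveModelPath(modelPath: string, webgpu: boolean, exists: (p: string) => boolean): string | null {
      const isSafetensors = ['.st', '.safetensors'].some(ext => modelPath.endsWith(ext));
      if (webgpu && !isSafetensors) {
        const stModelPath = modelPath.replace(/\.pth$/, '.st');
        return exists(stModelPath) ? stModelPath : null;
      }
      if (!webgpu && isSafetensors)
        return null;
      return modelPath;
    }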

@@ -85,7 +115,12 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
 
     await exit(1000).catch(() => {
     });
-    StartServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
+
+    const startServer = webgpu ?
+      (_: string, port: number, host: string) => StartWebGPUServer(port, host)
+      : StartServer;
+
+    startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
       modelConfig.modelParameters.device === 'CUDA-Beta'
     ).catch((e) => {
       const errMsg = e.message || e;
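
The ternary works because both branches are then invoked through one call shape; an illustrative type (the alias name is mine, not part of the commit):

    // StartServer is called as (python, port, host, rwkvBeta); the WebGPU lambda accepts
    // the same leading arguments but ignores the Python path and the trailing boolean,
    // so either value fits this signature.
    type ServerStarter = (python: string, port: number, host: string, rwkvBeta: boolean) => Promise<string>;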

@@ -104,19 +139,23 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
       if (r.ok && !loading) {
         loading = true;
         clearInterval(intervalId);
-        await getStatus().then(status => {
-          if (status)
-            commonStore.setStatus(status);
-        });
+        if (!webgpu) {
+          await getStatus().then(status => {
+            if (status)
+              commonStore.setStatus(status);
+          });
+        }
         commonStore.setStatus({ status: ModelStatus.Loading });
         toast(t('Loading Model'), { type: 'info' });
-        updateConfig({
-          max_tokens: modelConfig.apiParameters.maxResponseToken,
-          temperature: modelConfig.apiParameters.temperature,
-          top_p: modelConfig.apiParameters.topP,
-          presence_penalty: modelConfig.apiParameters.presencePenalty,
-          frequency_penalty: modelConfig.apiParameters.frequencyPenalty
-        });
+        if (!webgpu) {
+          updateConfig({
+            max_tokens: modelConfig.apiParameters.maxResponseToken,
+            temperature: modelConfig.apiParameters.temperature,
+            top_p: modelConfig.apiParameters.topP,
+            presence_penalty: modelConfig.apiParameters.presencePenalty,
+            frequency_penalty: modelConfig.apiParameters.frequencyPenalty
+          });
+        }
 
         const strategy = getStrategy(modelConfig);
         let customCudaFile = '';

Frontend Configs page:

@@ -13,13 +13,14 @@ import { Page } from '../components/Page';
 import { useNavigate } from 'react-router';
 import { RunButton } from '../components/RunButton';
 import { updateConfig } from '../apis';
-import { ConvertModel, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
-import { getStrategy } from '../utils';
+import { ConvertModel, ConvertSafetensors, FileExists, GetPyError } from '../../wailsjs/go/backend_golang/App';
+import { checkDependencies, getStrategy } from '../utils';
 import { useTranslation } from 'react-i18next';
 import { WindowShow } from '../../wailsjs/runtime/runtime';
 import strategyImg from '../assets/images/strategy.jpg';
 import strategyZhImg from '../assets/images/strategy_zh.jpg';
 import { ResetConfigsButton } from '../components/ResetConfigsButton';
+import { useMediaQuery } from 'usehooks-ts';
 
 export type ApiParameters = {
   apiPort: number

@@ -56,6 +57,7 @@ export const Configs: FC = observer(() => {
   const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex);
   const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]);
   const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false);
+  const mq = useMediaQuery('(min-width: 640px)');
   const navigate = useNavigate();
   const port = selectedConfig.apiParameters.apiPort;
 

@@ -128,7 +130,8 @@ export const Configs: FC = observer(() => {
             setSelectedIndex(0);
             setSelectedConfig(commonStore.modelConfigs[0]);
           }} />
-        <ToolTipButton desc={t('Save Config')} icon={<Save20Regular />} onClick={onClickSave} />
+        <ToolTipButton desc={mq ? '' : t('Save Config')} icon={<Save20Regular />} text={mq ? t('Save Config') : ''}
+          onClick={onClickSave} />
       </div>
       <div className="flex items-center gap-4">
         <Label>{t('Config Name')}</Label>

@@ -237,40 +240,84 @@ export const Configs: FC = observer(() => {
           }} />
         </div>
       } />
-      <ToolTipButton text={t('Convert')}
-        desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
-        onClick={async () => {
-          if (commonStore.platform === 'darwin') {
-            toast(t('MacOS is not yet supported for performing this operation, please do it manually.'), { type: 'info' });
-            return;
-          } else if (commonStore.platform === 'linux') {
-            toast(t('Linux is not yet supported for performing this operation, please do it manually.'), { type: 'info' });
-            return;
-          }
-
-          const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
-          if (await FileExists(modelPath)) {
-            const strategy = getStrategy(selectedConfig);
-            const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
-            toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
-            ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
-              if (!await FileExists(newModelPath + '.pth')) {
-                toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
-              } else {
-                toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
-              }
-            }).catch(e => {
-              const errMsg = e.message || e;
-              if (errMsg.includes('path contains space'))
-                toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
-              else
-                toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
-            });
-            setTimeout(WindowShow, 1000);
-          } else {
-            toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
-          }
-        }} />
+      {
+        selectedConfig.modelParameters.device !== 'WebGPU' ?
+          <ToolTipButton text={t('Convert')}
+            desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
+            onClick={async () => {
+              if (commonStore.platform === 'darwin') {
+                toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
+                return;
+              } else if (commonStore.platform === 'linux') {
+                toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_model.py)', { type: 'info' });
+                return;
+              }
+
+              const ok = await checkDependencies(navigate);
+              if (!ok)
+                return;
+
+              const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
+              if (await FileExists(modelPath)) {
+                const strategy = getStrategy(selectedConfig);
+                const newModelPath = modelPath + '-' + strategy.replace(/[:> *+]/g, '-');
+                toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
+                ConvertModel(commonStore.settings.customPythonPath, modelPath, strategy, newModelPath).then(async () => {
+                  if (!await FileExists(newModelPath + '.pth')) {
+                    toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
+                  } else {
+                    toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
+                  }
+                }).catch(e => {
+                  const errMsg = e.message || e;
+                  if (errMsg.includes('path contains space'))
+                    toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
+                  else
+                    toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
+                });
+                setTimeout(WindowShow, 1000);
+              } else {
+                toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
+              }
+            }} /> :
+          <ToolTipButton text={t('Convert To Safe Tensors Format')}
+            desc=""
+            onClick={async () => {
+              if (commonStore.platform === 'darwin') {
+                toast(t('MacOS is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_safetensors.py)', { type: 'info' });
+                return;
+              } else if (commonStore.platform === 'linux') {
+                toast(t('Linux is not yet supported for performing this operation, please do it manually.') + ' (backend-python/convert_safetensors.py)', { type: 'info' });
+                return;
+              }
+
+              const ok = await checkDependencies(navigate);
+              if (!ok)
+                return;
+
+              const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
+              if (await FileExists(modelPath)) {
+                toast(t('Start Converting'), { autoClose: 1000, type: 'info' });
+                const newModelPath = modelPath.replace(/\.pth$/, '.st');
+                ConvertSafetensors(commonStore.settings.customPythonPath, modelPath, newModelPath).then(async () => {
+                  if (!await FileExists(newModelPath)) {
+                    toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
+                  } else {
+                    toast(`${t('Convert Success')} - ${newModelPath}`, { type: 'success' });
+                  }
+                }).catch(e => {
+                  const errMsg = e.message || e;
+                  if (errMsg.includes('path contains space'))
+                    toast(`${t('Convert Failed')} - ${t('File Path Cannot Contain Space')}`, { type: 'error' });
+                  else
+                    toast(`${t('Convert Failed')} - ${e.message || e}`, { type: 'error' });
+                });
+                setTimeout(WindowShow, 1000);
+              } else {
+                toast(`${t('Model Not Found')} - ${modelPath}`, { type: 'error' });
+              }
+            }} />
+      }
       <Labeled label={t('Strategy')} content={
         <Dropdown style={{ minWidth: 0 }} className="grow" value={t(selectedConfig.modelParameters.device)!}
           selectedOptions={[selectedConfig.modelParameters.device]}

@@ -285,7 +332,7 @@ export const Configs: FC = observer(() => {
           {commonStore.platform === 'darwin' && <Option value="MPS">MPS</Option>}
           <Option value="CUDA">CUDA</Option>
           <Option value="CUDA-Beta">{t('CUDA (Beta, Faster)')!}</Option>
-          <Option value="WebGPU" disabled>WebGPU</Option>
+          <Option value="WebGPU">WebGPU</Option>
           <Option value="Custom">{t('Custom')!}</Option>
         </Dropdown>
       } />

@@ -305,7 +352,7 @@ export const Configs: FC = observer(() => {
         }}>
           <Option>fp16</Option>
           <Option>int8</Option>
-          <Option>fp32</Option>
+          {selectedConfig.modelParameters.device !== 'WebGPU' && <Option>fp32</Option>}
         </Dropdown>
       } />
     }

@@ -353,7 +400,7 @@ export const Configs: FC = observer(() => {
       }
       {selectedConfig.modelParameters.device === 'Custom' && <div />}
       {
-        selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' &&
+        (selectedConfig.modelParameters.device.includes('CUDA') || selectedConfig.modelParameters.device === 'Custom') &&
         <Labeled label={t('Use Custom CUDA kernel to Accelerate')}
           desc={t('Enabling this option can greatly improve inference speed and save some VRAM, but there may be compatibility issues. If it fails to start, please turn off this option.')}
           content={

Frontend model utilities:

@@ -57,6 +57,8 @@ export async function refreshBuiltInModels(readCache: boolean = false) {
   return cache;
 }
 
+const modelSuffix = ['.pth', '.st', '.safetensors'];
+
 export async function refreshLocalModels(cache: {
   models: ModelSourceItem[]
 }, filter: boolean = true, initUnfinishedModels: boolean = false) {
@@ -65,7 +67,7 @@ export async function refreshLocalModels(cache: {
 
   await ListDirFiles(commonStore.settings.customModelsPath).then((data) => {
     cache.models.push(...data.flatMap(d => {
-      if (!d.isDir && d.name.endsWith('.pth'))
+      if (!d.isDir && modelSuffix.some((ext => d.name.endsWith(ext))))
         return [{
           name: d.name,
           size: d.size,
@@ -146,7 +148,7 @@ export async function refreshRemoteModels(cache: { models: ModelSourceItem[] })
     .catch(() => {
     });
   cache.models = cache.models.filter((model, index, self) => {
-    return model.name.endsWith('.pth')
+    return modelSuffix.some((ext => model.name.endsWith(ext)))
       && index === self.findIndex(
         m => m.name === model.name || (m.SHA256 && m.SHA256 === model.SHA256 && m.size === model.size));
   });
@@ -176,6 +178,9 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
       strategy += 'cpu ';
       strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
       break;
+    case 'WebGPU':
+      strategy += params.precision === 'int8' ? 'fp16i8' : 'fp16';
+      break;
     case 'CUDA':
     case 'CUDA-Beta':
       if (avoidOverflow)
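
A short illustrative sketch (not part of the commit) of what the two changes above do:

    const modelSuffix = ['.pth', '.st', '.safetensors'];

    // Local and remote model lists now keep safetensors files as well as .pth files.
    const files = ['rwkv-4-world.pth', 'rwkv-4-world.st', 'readme.txt'];
    const models = files.filter(name => modelSuffix.some(ext => name.endsWith(ext)));
    // -> ['rwkv-4-world.pth', 'rwkv-4-world.st']

    // The new getStrategy case maps the WebGPU device to an fp16 strategy string,
    // assuming nothing was appended before the switch for this device.
    const precision: string = 'int8';
    const webgpuStrategy = precision === 'int8' ? 'fp16i8' : 'fp16';
    // -> 'fp16i8'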

frontend/wailsjs/go/backend_golang/App.d.ts (generated, vendored, +4)

@@ -10,6 +10,8 @@ export function ConvertData(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
 
 export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Promise<string>;
 
+export function ConvertSafetensors(arg1:string,arg2:string,arg3:string):Promise<string>;
+
 export function CopyFile(arg1:string,arg2:string):Promise<void>;
 
 export function DeleteFile(arg1:string):Promise<void>;
@@ -48,6 +50,8 @@ export function SaveJson(arg1:string,arg2:any):Promise<void>;
 
 export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean):Promise<string>;
 
+export function StartWebGPUServer(arg1:number,arg2:string):Promise<string>;
+
 export function UpdateApp(arg1:string):Promise<boolean>;
 
 export function WslCommand(arg1:string):Promise<void>;

frontend/wailsjs/go/backend_golang/App.js (generated, +8)

@@ -18,6 +18,10 @@ export function ConvertModel(arg1, arg2, arg3, arg4) {
   return window['go']['backend_golang']['App']['ConvertModel'](arg1, arg2, arg3, arg4);
 }
 
+export function ConvertSafetensors(arg1, arg2, arg3) {
+  return window['go']['backend_golang']['App']['ConvertSafetensors'](arg1, arg2, arg3);
+}
+
 export function CopyFile(arg1, arg2) {
   return window['go']['backend_golang']['App']['CopyFile'](arg1, arg2);
 }
@@ -94,6 +98,10 @@ export function StartServer(arg1, arg2, arg3, arg4) {
   return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4);
 }
 
+export function StartWebGPUServer(arg1, arg2) {
+  return window['go']['backend_golang']['App']['StartWebGPUServer'](arg1, arg2);
+}
+
 export function UpdateApp(arg1) {
   return window['go']['backend_golang']['App']['UpdateApp'](arg1);
 }

main.go (+4)

@@ -49,6 +49,9 @@ var cyacInfo embed.FS
 //go:embed backend-python
 var py embed.FS
 
+//go:embed backend-rust
+var webgpu embed.FS
+
 //go:embed finetune
 var finetune embed.FS
 
@@ -63,6 +66,7 @@ func main() {
 	backend.CopyEmbed(cyac)
 	backend.CopyEmbed(cyacInfo)
 	backend.CopyEmbed(py)
+	backend.CopyEmbed(webgpu)
 	backend.CopyEmbed(finetune)
 	backend.CopyEmbed(midi)
 	backend.CopyEmbed(midiAssets)

Path exclusion list (file name not shown in this view):

@@ -2,6 +2,7 @@
 - ^backend-python/wkv_cuda_utils/
 - ^backend-python/get-pip\.py
 - ^backend-python/convert_model\.py
+- ^backend-python/convert_safetensors\.py
 - ^backend-python/utils/midi\.py
 - ^build/
 - ^finetune/lora/