add WebGPU Python Mode (https://github.com/cryscan/web-rwkv-py)
This commit is contained in:
@@ -48,6 +48,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
|
||||
const modelConfig = commonStore.getCurrentModelConfig();
|
||||
const webgpu = modelConfig.modelParameters.device === 'WebGPU';
|
||||
const webgpuPython = modelConfig.modelParameters.device === 'WebGPU (Python)';
|
||||
const cpp = modelConfig.modelParameters.device === 'CPU (rwkv.cpp)';
|
||||
let modelName = '';
|
||||
let modelPath = '';
|
||||
@@ -77,7 +78,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
});
|
||||
};
|
||||
|
||||
if (webgpu) {
|
||||
if (webgpu || webgpuPython) {
|
||||
if (!['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
|
||||
const stModelPath = modelPath.replace(/\.pth$/, '.st');
|
||||
if (await FileExists(stModelPath)) {
|
||||
@@ -92,7 +93,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
return;
|
||||
} else {
|
||||
toastWithButton(t('Please convert model to safe tensors format first'), t('Convert'), () => {
|
||||
convertToSt(modelConfig);
|
||||
convertToSt(modelConfig, navigate);
|
||||
});
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
return;
|
||||
@@ -100,7 +101,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
}
|
||||
}
|
||||
|
||||
if (!webgpu) {
|
||||
if (!webgpu && !webgpuPython) {
|
||||
if (['.st', '.safetensors'].some(ext => modelPath.endsWith(ext))) {
|
||||
toast(t('Please change Strategy to WebGPU to use safetensors format'), { type: 'error' });
|
||||
commonStore.setStatus({ status: ModelStatus.Offline });
|
||||
@@ -176,7 +177,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
const isUsingCudaBeta = modelConfig.modelParameters.device === 'CUDA-Beta';
|
||||
|
||||
startServer(commonStore.settings.customPythonPath, port, commonStore.settings.host !== '127.0.0.1' ? '0.0.0.0' : '127.0.0.1',
|
||||
!!modelConfig.enableWebUI, isUsingCudaBeta, cpp
|
||||
!!modelConfig.enableWebUI, isUsingCudaBeta, cpp, webgpuPython
|
||||
).catch((e) => {
|
||||
const errMsg = e.message || e;
|
||||
if (errMsg.includes('path contains space'))
|
||||
@@ -216,7 +217,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
|
||||
const strategy = getStrategy(modelConfig);
|
||||
let customCudaFile = '';
|
||||
if ((modelConfig.modelParameters.device.includes('CUDA') || modelConfig.modelParameters.device === 'Custom')
|
||||
if ((modelConfig.modelParameters.device.startsWith('CUDA') || modelConfig.modelParameters.device === 'Custom')
|
||||
&& modelConfig.modelParameters.useCustomCuda
|
||||
&& !strategy.split('->').some(s => ['cuda', 'fp32'].every(v => s.includes(v)))) {
|
||||
if (commonStore.platform === 'windows') {
|
||||
@@ -264,7 +265,7 @@ export const RunButton: FC<{ onClickRun?: MouseEventHandler, iconMode?: boolean
|
||||
navigate({ pathname: '/' + buttonName.toLowerCase() });
|
||||
};
|
||||
|
||||
if ((modelConfig.modelParameters.device === 'CUDA' || modelConfig.modelParameters.device === 'CUDA-Beta') &&
|
||||
if (modelConfig.modelParameters.device.startsWith('CUDA') &&
|
||||
modelConfig.modelParameters.storedLayers < modelConfig.modelParameters.maxStoredLayers &&
|
||||
commonStore.monitorData && commonStore.monitorData.totalVram !== 0 &&
|
||||
(commonStore.monitorData.usedVram / commonStore.monitorData.totalVram) < 0.9)
|
||||
|
||||
@@ -246,7 +246,7 @@ const Configs: FC = observer(() => {
|
||||
</div>
|
||||
} />
|
||||
{
|
||||
selectedConfig.modelParameters.device !== 'WebGPU' ?
|
||||
!selectedConfig.modelParameters.device.startsWith('WebGPU') ?
|
||||
(selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' ?
|
||||
<ToolTipButton text={t('Convert')}
|
||||
desc={t('Convert model with these configs. Using a converted model will greatly improve the loading speed, but model parameters of the converted model cannot be modified.')}
|
||||
@@ -256,7 +256,7 @@ const Configs: FC = observer(() => {
|
||||
onClick={() => convertToGGML(selectedConfig, navigate)} />)
|
||||
: <ToolTipButton text={t('Convert To Safe Tensors Format')}
|
||||
desc=""
|
||||
onClick={() => convertToSt(selectedConfig)} />
|
||||
onClick={() => convertToSt(selectedConfig, navigate)} />
|
||||
}
|
||||
<Labeled label={t('Strategy')} content={
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow" value={t(selectedConfig.modelParameters.device)!}
|
||||
@@ -274,6 +274,7 @@ const Configs: FC = observer(() => {
|
||||
<Option value="CUDA">CUDA</Option>
|
||||
<Option value="CUDA-Beta">{t('CUDA (Beta, Faster)')!}</Option>
|
||||
<Option value="WebGPU">WebGPU</Option>
|
||||
<Option value="WebGPU (Python)">WebGPU (Python)</Option>
|
||||
<Option value="Custom">{t('Custom')!}</Option>
|
||||
</Dropdown>
|
||||
} />
|
||||
@@ -281,7 +282,8 @@ const Configs: FC = observer(() => {
|
||||
selectedConfig.modelParameters.device !== 'Custom' && <Labeled label={t('Precision')}
|
||||
desc={t('int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality.')}
|
||||
content={
|
||||
<Dropdown style={{ minWidth: 0 }} className="grow"
|
||||
<Dropdown disabled={selectedConfig.modelParameters.device === 'WebGPU (Python)'}
|
||||
style={{ minWidth: 0 }} className="grow"
|
||||
value={selectedConfig.modelParameters.precision}
|
||||
selectedOptions={[selectedConfig.modelParameters.precision]}
|
||||
onOptionSelect={(_, data) => {
|
||||
@@ -302,12 +304,12 @@ const Configs: FC = observer(() => {
|
||||
} />
|
||||
}
|
||||
{
|
||||
selectedConfig.modelParameters.device.includes('CUDA') &&
|
||||
selectedConfig.modelParameters.device.startsWith('CUDA') &&
|
||||
<Labeled label={t('Current Strategy')}
|
||||
content={<Text> {getStrategy(selectedConfig)} </Text>} />
|
||||
}
|
||||
{
|
||||
selectedConfig.modelParameters.device.includes('CUDA') &&
|
||||
selectedConfig.modelParameters.device.startsWith('CUDA') &&
|
||||
<Labeled label={t('Stored Layers')}
|
||||
desc={t('Number of the neural network layers loaded into VRAM, the more you load, the faster the speed, but it consumes more VRAM. (If your VRAM is not enough, it will fail to load)')}
|
||||
content={
|
||||
@@ -320,7 +322,7 @@ const Configs: FC = observer(() => {
|
||||
}} />
|
||||
} />
|
||||
}
|
||||
{selectedConfig.modelParameters.device.includes('CUDA') && <div />}
|
||||
{selectedConfig.modelParameters.device.startsWith('CUDA') && <div />}
|
||||
{
|
||||
displayStrategyImg &&
|
||||
<img style={{ width: '80vh', height: 'auto', zIndex: 100 }}
|
||||
@@ -345,7 +347,7 @@ const Configs: FC = observer(() => {
|
||||
}
|
||||
{selectedConfig.modelParameters.device === 'Custom' && <div />}
|
||||
{
|
||||
(selectedConfig.modelParameters.device.includes('CUDA') || selectedConfig.modelParameters.device === 'Custom') &&
|
||||
(selectedConfig.modelParameters.device.startsWith('CUDA') || selectedConfig.modelParameters.device === 'Custom') &&
|
||||
<Labeled label={t('Use Custom CUDA kernel to Accelerate')}
|
||||
desc={t('Enabling this option can greatly improve inference speed and save some VRAM, but there may be compatibility issues (output garbled). If it fails to start, please turn off this option, or try to upgrade your gpu driver.')}
|
||||
content={
|
||||
|
||||
@@ -6,7 +6,7 @@ export type ApiParameters = {
|
||||
presencePenalty: number;
|
||||
frequencyPenalty: number;
|
||||
}
|
||||
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'MPS' | 'Custom';
|
||||
export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom';
|
||||
export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';
|
||||
export type ModelParameters = {
|
||||
// different models can not have the same name
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
ConvertGGML,
|
||||
ConvertModel,
|
||||
ConvertSafetensors,
|
||||
ConvertSafetensorsWithPython,
|
||||
FileExists,
|
||||
GetPyError
|
||||
} from '../../wailsjs/go/backend_golang/App';
|
||||
@@ -51,12 +52,22 @@ export const convertModel = async (selectedConfig: ModelConfig, navigate: Naviga
|
||||
};
|
||||
|
||||
|
||||
export const convertToSt = async (selectedConfig: ModelConfig) => {
|
||||
export const convertToSt = async (selectedConfig: ModelConfig, navigate: NavigateFunction) => {
|
||||
const webgpuPython = selectedConfig.modelParameters.device === 'WebGPU (Python)';
|
||||
if (webgpuPython) {
|
||||
const ok = await checkDependencies(navigate);
|
||||
if (!ok)
|
||||
return;
|
||||
}
|
||||
|
||||
const modelPath = `${commonStore.settings.customModelsPath}/${selectedConfig.modelParameters.modelName}`;
|
||||
if (await FileExists(modelPath)) {
|
||||
toast(t('Start Converting'), { autoClose: 2000, type: 'info' });
|
||||
const newModelPath = modelPath.replace(/\.pth$/, '.st');
|
||||
ConvertSafetensors(modelPath, newModelPath).then(async () => {
|
||||
const convert = webgpuPython ?
|
||||
(input: string, output: string) => ConvertSafetensorsWithPython(commonStore.settings.customPythonPath, input, output)
|
||||
: ConvertSafetensors;
|
||||
convert(modelPath, newModelPath).then(async () => {
|
||||
if (!await FileExists(newModelPath)) {
|
||||
if (commonStore.platform === 'windows' || commonStore.platform === 'linux')
|
||||
toast(t('Convert Failed') + ' - ' + await GetPyError(), { type: 'error' });
|
||||
|
||||
@@ -192,6 +192,7 @@ export const getStrategy = (modelConfig: ModelConfig | undefined = undefined) =>
|
||||
strategy += params.precision === 'int8' ? 'fp32i8' : 'fp32';
|
||||
break;
|
||||
case 'WebGPU':
|
||||
case 'WebGPU (Python)':
|
||||
strategy += params.precision === 'nf4' ? 'fp16i4' : params.precision === 'int8' ? 'fp16i8' : 'fp16';
|
||||
break;
|
||||
case 'CUDA':
|
||||
@@ -307,7 +308,7 @@ export function getServerRoot(defaultLocalPort: number, isCore: boolean = false)
|
||||
const coreCustomApiUrl = commonStore.settings.coreApiUrl.trim().replace(/\/$/, '');
|
||||
if (isCore && coreCustomApiUrl)
|
||||
return coreCustomApiUrl;
|
||||
|
||||
|
||||
const defaultRoot = `http://127.0.0.1:${defaultLocalPort}`;
|
||||
if (commonStore.status.status !== ModelStatus.Offline)
|
||||
return defaultRoot;
|
||||
|
||||
4
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
4
frontend/wailsjs/go/backend_golang/App.d.ts
generated
vendored
@@ -16,6 +16,8 @@ export function ConvertModel(arg1:string,arg2:string,arg3:string,arg4:string):Pr
|
||||
|
||||
export function ConvertSafetensors(arg1:string,arg2:string):Promise<string>;
|
||||
|
||||
export function ConvertSafetensorsWithPython(arg1:string,arg2:string,arg3:string):Promise<string>;
|
||||
|
||||
export function CopyFile(arg1:string,arg2:string):Promise<void>;
|
||||
|
||||
export function DeleteFile(arg1:string):Promise<void>;
|
||||
@@ -64,7 +66,7 @@ export function SaveJson(arg1:string,arg2:any):Promise<void>;
|
||||
|
||||
export function StartFile(arg1:string):Promise<void>;
|
||||
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean):Promise<string>;
|
||||
export function StartServer(arg1:string,arg2:number,arg3:string,arg4:boolean,arg5:boolean,arg6:boolean,arg7:boolean):Promise<string>;
|
||||
|
||||
export function StartWebGPUServer(arg1:number,arg2:string):Promise<string>;
|
||||
|
||||
|
||||
8
frontend/wailsjs/go/backend_golang/App.js
generated
8
frontend/wailsjs/go/backend_golang/App.js
generated
@@ -30,6 +30,10 @@ export function ConvertSafetensors(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['ConvertSafetensors'](arg1, arg2);
|
||||
}
|
||||
|
||||
export function ConvertSafetensorsWithPython(arg1, arg2, arg3) {
|
||||
return window['go']['backend_golang']['App']['ConvertSafetensorsWithPython'](arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
export function CopyFile(arg1, arg2) {
|
||||
return window['go']['backend_golang']['App']['CopyFile'](arg1, arg2);
|
||||
}
|
||||
@@ -126,8 +130,8 @@ export function StartFile(arg1) {
|
||||
return window['go']['backend_golang']['App']['StartFile'](arg1);
|
||||
}
|
||||
|
||||
export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6);
|
||||
export function StartServer(arg1, arg2, arg3, arg4, arg5, arg6, arg7) {
|
||||
return window['go']['backend_golang']['App']['StartServer'](arg1, arg2, arg3, arg4, arg5, arg6, arg7);
|
||||
}
|
||||
|
||||
export function StartWebGPUServer(arg1, arg2) {
|
||||
|
||||
Reference in New Issue
Block a user