WebGPU (Python) strategy
This commit is contained in:
		
							parent
							
								
									2f45e9c33a
								
							
						
					
					
						commit
						18d4b2304e
					
				
							
								
								
									
										9
									
								
								backend-python/rwkv_pip/webgpu/model.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										9
									
								
								backend-python/rwkv_pip/webgpu/model.py
									
									
									
									
										vendored
									
									
								
							@ -12,8 +12,13 @@ except ModuleNotFoundError:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RWKV:
 | 
			
		||||
    def __init__(self, model_path: str, strategy=None):
 | 
			
		||||
        self.model = wrp.v5.Model(model_path, turbo=False)
 | 
			
		||||
    def __init__(self, model_path: str, strategy: str = None):
 | 
			
		||||
        self.model = wrp.v5.Model(
 | 
			
		||||
            model_path,
 | 
			
		||||
            turbo=False,
 | 
			
		||||
            quant=32 if "i8" in strategy else None,
 | 
			
		||||
            quant_nf4=26 if "i4" in strategy else None,
 | 
			
		||||
        )
 | 
			
		||||
        self.w = {}  # fake weight
 | 
			
		||||
        self.w["emb.weight"] = [0] * wrp.peek_info(model_path).num_vocab
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -282,7 +282,7 @@ const Configs: FC = observer(() => {
 | 
			
		||||
                  selectedConfig.modelParameters.device !== 'Custom' && <Labeled label={t('Precision')}
 | 
			
		||||
                    desc={t('int8 uses less VRAM, but has slightly lower quality. fp16 has higher quality.')}
 | 
			
		||||
                    content={
 | 
			
		||||
                      <Dropdown disabled={selectedConfig.modelParameters.device === 'WebGPU (Python)'}
 | 
			
		||||
                      <Dropdown
 | 
			
		||||
                        style={{ minWidth: 0 }} className="grow"
 | 
			
		||||
                        value={selectedConfig.modelParameters.precision}
 | 
			
		||||
                        selectedOptions={[selectedConfig.modelParameters.precision]}
 | 
			
		||||
@ -296,8 +296,8 @@ const Configs: FC = observer(() => {
 | 
			
		||||
                        {selectedConfig.modelParameters.device !== 'CPU' && selectedConfig.modelParameters.device !== 'MPS' &&
 | 
			
		||||
                          <Option>fp16</Option>}
 | 
			
		||||
                        {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && <Option>int8</Option>}
 | 
			
		||||
                        {selectedConfig.modelParameters.device === 'WebGPU' && <Option>nf4</Option>}
 | 
			
		||||
                        {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && selectedConfig.modelParameters.device !== 'WebGPU' &&
 | 
			
		||||
                        {selectedConfig.modelParameters.device.startsWith('WebGPU') && <Option>nf4</Option>}
 | 
			
		||||
                        {selectedConfig.modelParameters.device !== 'CPU (rwkv.cpp)' && !selectedConfig.modelParameters.device.startsWith('WebGPU') &&
 | 
			
		||||
                          <Option>fp32</Option>}
 | 
			
		||||
                        {selectedConfig.modelParameters.device === 'CPU (rwkv.cpp)' && <Option>Q5_1</Option>}
 | 
			
		||||
                      </Dropdown>
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user