expose global_penalty

josc146 2024-03-02 17:50:41 +08:00
parent 53a5574080
commit 4f14074a75
6 changed files with 92 additions and 31 deletions

View File

@@ -40,6 +40,7 @@ class AbstractRWKV(ABC):
         self.penalty_alpha_presence = 0
         self.penalty_alpha_frequency = 1
         self.penalty_decay = 0.996
+        self.global_penalty = False
 
     @abstractmethod
     def adjust_occurrence(self, occurrence: Dict, token: int):
@@ -403,8 +404,8 @@ class TextRWKV(AbstractRWKV):
                     + occurrence[n] * self.penalty_alpha_frequency
                 )
 
-            # comment the codes below to get the same generated results as the official RWKV Gradio
-            if i == 0:
+            # set global_penalty to False to get the same generated results as the official RWKV Gradio
+            if self.global_penalty and i == 0:
                 for token in self.model_tokens:
                     token = int(token)
                     if token not in self.AVOID_PENALTY_TOKENS:
@@ -673,6 +674,7 @@ class ModelConfigBody(BaseModel):
     frequency_penalty: float = Field(default=None, ge=-2, le=2)
     penalty_decay: float = Field(default=None, ge=0.99, le=0.999)
     top_k: int = Field(default=None, ge=0, le=25)
+    global_penalty: bool = Field(default=None)
 
     model_config = {
         "json_schema_extra": {
@@ -683,6 +685,7 @@ class ModelConfigBody(BaseModel):
                 "presence_penalty": 0,
                 "frequency_penalty": 1,
                 "penalty_decay": 0.996,
+                "global_penalty": False,
             }
         }
     }
@@ -706,6 +709,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
         model.penalty_decay = body.penalty_decay
     if body.top_k is not None:
         model.top_k = body.top_k
+    if body.global_penalty is not None:
+        model.global_penalty = body.global_penalty
 
 
 def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
@@ -717,4 +722,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
         frequency_penalty=model.penalty_alpha_frequency,
         penalty_decay=model.penalty_decay,
         top_k=model.top_k,
+        global_penalty=model.global_penalty,
     )
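In effect, global_penalty decides whether the tokens of the submitted prompt are seeded into the occurrence table before the first generated token, so the presence/frequency penalty also discourages repeating the prompt. A minimal, self-contained sketch of that behaviour (simplified, with hypothetical names; the real logic lives in TextRWKV and adjust_occurrence above):

```python
# Sketch only: illustrates what the global_penalty flag controls.
from typing import Dict, List

AVOID_PENALTY_TOKENS = frozenset({0})  # placeholder set of tokens never penalized


def seed_prompt_penalty(
    occurrence: Dict[int, float],
    prompt_tokens: List[int],
    global_penalty: bool,
) -> None:
    """If global_penalty is True, count prompt tokens as if they had already
    been generated, so the response is penalized for repeating them; if False,
    skip them, matching the official RWKV Gradio behaviour."""
    if not global_penalty:
        return
    for token in prompt_tokens:
        if token not in AVOID_PENALTY_TOKENS:
            occurrence[token] = occurrence.get(token, 0) + 1


def apply_penalty(logits, occurrence, presence=0.0, frequency=1.0):
    # Same shape as the loop in the diff above: each seen token gets
    # presence + count * frequency subtracted from its logit.
    for token, count in occurrence.items():
        logits[token] -= presence + count * frequency
    return logits
```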

View File

@@ -345,5 +345,7 @@
   "Quantized Layers": "量子化されたレイヤー",
   "Number of the neural network layers quantized with current precision, the more you quantize, the lower the VRAM usage, but the quality correspondingly decreases.": "現在の精度で量子化されたニューラルネットワークのレイヤーの数、量子化するほどVRAMの使用量が低くなりますが、品質も相応に低下します。",
   "Parallel Token Chunk Size": "並列トークンチャンクサイズ",
-  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一度に並列で処理される最大トークン数。高性能なGPUの場合、64または128になります高速。"
+  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一度に並列で処理される最大トークン数。高性能なGPUの場合、64または128になります高速。",
+  "Global Penalty": "グローバルペナルティ",
+  "When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.": "レスポンスを生成する際、提出されたプロンプトをペナルティ要因として含めるかどうか。これをオフにすると、公式RWKV Gradioと同じ生成結果を得ることができます。生成された結果に重複がある場合、これをオンにすることで重複の生成を回避するのに役立ちます。"
 }

View File

@@ -345,5 +345,7 @@
   "Quantized Layers": "量化层数",
   "Number of the neural network layers quantized with current precision, the more you quantize, the lower the VRAM usage, but the quality correspondingly decreases.": "神经网络以当前精度量化的层数, 量化越多, 占用显存越低, 但质量相应下降",
   "Parallel Token Chunk Size": "并行Token块大小",
-  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一次最多可以并行处理的token数量. 对于高端显卡, 这可以是64或128 (更快)"
+  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一次最多可以并行处理的token数量. 对于高端显卡, 这可以是64或128 (更快)",
+  "Global Penalty": "全局惩罚",
+  "When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.": "生成响应时, 是否将提交的prompt也纳入到惩罚项. 关闭此项将得到与RWKV官方Gradio完全一致的生成结果. 如果你发现生成结果出现重复, 那么开启此项有助于避免生成重复"
 }

View File

@@ -35,6 +35,7 @@ import { ResetConfigsButton } from '../components/ResetConfigsButton';
 import { useMediaQuery } from 'usehooks-ts';
 import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
 import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
+import { defaultPenaltyDecay } from './defaultConfigs';
 
 const ConfigSelector: FC<{
   selectedIndex: number,
@@ -66,14 +67,17 @@ const Configs: FC = observer(() => {
   const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex);
   const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]);
   const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false);
-  const advancedHeaderRef = useRef<HTMLDivElement>(null);
+  const advancedHeaderRef1 = useRef<HTMLDivElement>(null);
+  const advancedHeaderRef2 = useRef<HTMLDivElement>(null);
   const mq = useMediaQuery('(min-width: 640px)');
   const navigate = useNavigate();
   const port = selectedConfig.apiParameters.apiPort;
 
   useEffect(() => {
-    if (advancedHeaderRef.current)
-      (advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0';
+    if (advancedHeaderRef1.current)
+      (advancedHeaderRef1.current.firstElementChild as HTMLElement).style.padding = '0';
+    if (advancedHeaderRef2.current)
+      (advancedHeaderRef2.current.firstElementChild as HTMLElement).style.padding = '0';
   }, []);
 
   const updateSelectedIndex = useCallback((newIndex: number) => {
@@ -113,7 +117,9 @@ const Configs: FC = observer(() => {
       temperature: selectedConfig.apiParameters.temperature,
       top_p: selectedConfig.apiParameters.topP,
       presence_penalty: selectedConfig.apiParameters.presencePenalty,
-      frequency_penalty: selectedConfig.apiParameters.frequencyPenalty
+      frequency_penalty: selectedConfig.apiParameters.frequencyPenalty,
+      penalty_decay: selectedConfig.apiParameters.penaltyDecay,
+      global_penalty: selectedConfig.apiParameters.globalPenalty
     });
     toast(t('Config Saved'), { autoClose: 300, type: 'success' });
   };
@@ -194,6 +200,16 @@ const Configs: FC = observer(() => {
                   });
                 }} />
             } />
+          <Accordion className="sm:col-span-2" collapsible
+            openItems={!commonStore.apiParamsCollapsed && 'advanced'}
+            onToggle={(e, data) => {
+              if (data.value === 'advanced')
+                commonStore.setApiParamsCollapsed(!commonStore.apiParamsCollapsed);
+            }}>
+            <AccordionItem value="advanced">
+              <AccordionHeader ref={advancedHeaderRef1} size="small">{t('Advanced')}</AccordionHeader>
+              <AccordionPanel>
+                <div className="grid grid-cols-1 sm:grid-cols-2 gap-2">
           <Labeled label={t('Presence Penalty') + ' *'}
             desc={t('Positive values penalize new tokens based on whether they appear in the text so far, increasing the model\'s likelihood to talk about new topics.')}
             content={
@@ -216,6 +232,35 @@ const Configs: FC = observer(() => {
                   });
                 }} />
             } />
+          <Labeled
+            label={t('Penalty Decay')
+              + ((!selectedConfig.apiParameters.penaltyDecay || selectedConfig.apiParameters.penaltyDecay === defaultPenaltyDecay)
+                ? ` (${t('Default')})` : '')
+              + ' *'}
+            desc={t('If you don\'t know what it is, keep it default.')}
+            content={
+              <ValuedSlider value={selectedConfig.apiParameters.penaltyDecay || defaultPenaltyDecay}
+                min={0.99} max={0.999} step={0.001} toFixed={3} input
+                onChange={(e, data) => {
+                  setSelectedConfigApiParams({
+                    penaltyDecay: data.value
+                  });
+                }} />
+            } />
+          <Labeled label={t('Global Penalty') + ' *'}
+            desc={t('When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.')}
+            content={
+              <Switch checked={selectedConfig.apiParameters.globalPenalty}
+                onChange={(e, data) => {
+                  setSelectedConfigApiParams({
+                    globalPenalty: data.checked
+                  });
+                }} />
+            } />
+                </div>
+              </AccordionPanel>
+            </AccordionItem>
+          </Accordion>
         </div>
       }
     />
@@ -410,7 +455,7 @@ const Configs: FC = observer(() => {
             commonStore.setModelParamsCollapsed(!commonStore.modelParamsCollapsed);
           }}>
           <AccordionItem value="advanced">
-            <AccordionHeader ref={advancedHeaderRef} size="small">{t('Advanced')}</AccordionHeader>
+            <AccordionHeader ref={advancedHeaderRef2} size="small">{t('Advanced')}</AccordionHeader>
            <AccordionPanel>
             <div className="flex flex-col">
               <div className="flex grow">

View File

@@ -127,6 +127,7 @@ class CommonStore {
   // configs
   currentModelConfigIndex: number = 0;
   modelConfigs: ModelConfig[] = [];
+  apiParamsCollapsed: boolean = true;
   modelParamsCollapsed: boolean = true;
   // models
   activeModelListTags: string[] = [];
@@ -324,6 +325,10 @@ class CommonStore {
     this.advancedCollapsed = value;
   }
 
+  setApiParamsCollapsed(value: boolean) {
+    this.apiParamsCollapsed = value;
+  }
+
   setModelParamsCollapsed(value: boolean) {
     this.modelParamsCollapsed = value;
   }

View File

@@ -6,6 +6,7 @@ export type ApiParameters = {
   presencePenalty: number;
   frequencyPenalty: number;
   penaltyDecay?: number;
+  globalPenalty?: boolean;
 }
 
 export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom';
 export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';