expose global_penalty
commit 4f14074a75
parent 53a5574080
@@ -40,6 +40,7 @@ class AbstractRWKV(ABC):
         self.penalty_alpha_presence = 0
         self.penalty_alpha_frequency = 1
         self.penalty_decay = 0.996
+        self.global_penalty = False

     @abstractmethod
     def adjust_occurrence(self, occurrence: Dict, token: int):
@@ -403,8 +404,8 @@ class TextRWKV(AbstractRWKV):
                     + occurrence[n] * self.penalty_alpha_frequency
                 )

-            # comment the codes below to get the same generated results as the official RWKV Gradio
-            if i == 0:
+            # set global_penalty to False to get the same generated results as the official RWKV Gradio
+            if self.global_penalty and i == 0:
                 for token in self.model_tokens:
                     token = int(token)
                     if token not in self.AVOID_PENALTY_TOKENS:
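For context, the hunk above applies the usual occurrence-based repetition penalty (each seen token's logit is reduced by penalty_alpha_presence + occurrence[n] * penalty_alpha_frequency), and the new global_penalty flag decides whether the prompt tokens already held in model_tokens are counted as well on the first generation step. The sketch below is illustrative only: the body of the prompt-seeding loop is an assumption (the diff truncates it), and these helper names do not exist in the repository.

# Illustrative sketch of the penalty scheme touched by the hunk above.
# The prompt-seeding body is assumed; helper names are hypothetical.
from typing import Dict, List, Set


def apply_repetition_penalty(
    logits: List[float],
    occurrence: Dict[int, float],
    alpha_presence: float = 0.0,
    alpha_frequency: float = 1.0,
) -> None:
    # Reduce, in place, the logit of every token that has already occurred.
    for n, count in occurrence.items():
        logits[n] -= alpha_presence + count * alpha_frequency


def seed_occurrence_from_prompt(
    prompt_tokens: List[int], avoid_penalty_tokens: Set[int]
) -> Dict[int, float]:
    # What enabling global_penalty amounts to: prompt tokens also count
    # toward the penalty (assumed behaviour, mirroring the loop header above).
    occurrence: Dict[int, float] = {}
    for token in prompt_tokens:
        token = int(token)
        if token not in avoid_penalty_tokens:
            occurrence[token] = occurrence.get(token, 0) + 1
    return occurrence


if __name__ == "__main__":
    logits = [0.5, 1.2, -0.3, 2.0]
    occurrence = seed_occurrence_from_prompt([1, 3, 3], avoid_penalty_tokens={0})
    apply_repetition_penalty(logits, occurrence)
    print(logits)  # logits for tokens 1 and 3 are now lower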
@@ -673,6 +674,7 @@ class ModelConfigBody(BaseModel):
     frequency_penalty: float = Field(default=None, ge=-2, le=2)
     penalty_decay: float = Field(default=None, ge=0.99, le=0.999)
     top_k: int = Field(default=None, ge=0, le=25)
+    global_penalty: bool = Field(default=None)

     model_config = {
         "json_schema_extra": {
@@ -683,6 +685,7 @@ class ModelConfigBody(BaseModel):
                 "presence_penalty": 0,
                 "frequency_penalty": 1,
                 "penalty_decay": 0.996,
+                "global_penalty": False,
             }
         }
     }
@@ -706,6 +709,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody):
         model.penalty_decay = body.penalty_decay
     if body.top_k is not None:
         model.top_k = body.top_k
+    if body.global_penalty is not None:
+        model.global_penalty = body.global_penalty


 def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
@@ -717,4 +722,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody:
         frequency_penalty=model.penalty_alpha_frequency,
         penalty_decay=model.penalty_decay,
         top_k=model.top_k,
+        global_penalty=model.global_penalty,
     )
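Taken together, the backend hunks let a caller toggle the new flag through the existing config body. A hedged usage sketch follows: ModelConfigBody, set_rwkv_config, and get_rwkv_config are the names from the diff, but the already-loaded model instance and the surrounding module context are assumed, since they are not shown here.

# Assumes this runs next to the definitions changed above and that `model` is an
# already-loaded AbstractRWKV instance (not shown in the diff).
body = ModelConfigBody(
    presence_penalty=0,
    frequency_penalty=1,
    penalty_decay=0.996,
    global_penalty=False,  # False reproduces the official RWKV Gradio results
)
set_rwkv_config(model, body)  # fields left as None are skipped by the `is not None` checks
assert get_rwkv_config(model).global_penalty is False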
@@ -345,5 +345,7 @@
   "Quantized Layers": "量子化されたレイヤー",
   "Number of the neural network layers quantized with current precision, the more you quantize, the lower the VRAM usage, but the quality correspondingly decreases.": "現在の精度で量子化されたニューラルネットワークのレイヤーの数、量子化するほどVRAMの使用量が低くなりますが、品質も相応に低下します。",
   "Parallel Token Chunk Size": "並列トークンチャンクサイズ",
-  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一度に並列で処理される最大トークン数。高性能なGPUの場合、64または128になります(高速)。"
+  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一度に並列で処理される最大トークン数。高性能なGPUの場合、64または128になります(高速)。",
+  "Global Penalty": "グローバルペナルティ",
+  "When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.": "レスポンスを生成する際、提出されたプロンプトをペナルティ要因として含めるかどうか。これをオフにすると、公式RWKV Gradioと同じ生成結果を得ることができます。生成された結果に重複がある場合、これをオンにすることで重複の生成を回避するのに役立ちます。"
 }
@@ -345,5 +345,7 @@
   "Quantized Layers": "量化层数",
   "Number of the neural network layers quantized with current precision, the more you quantize, the lower the VRAM usage, but the quality correspondingly decreases.": "神经网络以当前精度量化的层数, 量化越多, 占用显存越低, 但质量相应下降",
   "Parallel Token Chunk Size": "并行Token块大小",
-  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一次最多可以并行处理的token数量. 对于高端显卡, 这可以是64或128 (更快)"
+  "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一次最多可以并行处理的token数量. 对于高端显卡, 这可以是64或128 (更快)",
+  "Global Penalty": "全局惩罚",
+  "When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.": "生成响应时, 是否将提交的prompt也纳入到惩罚项. 关闭此项将得到与RWKV官方Gradio完全一致的生成结果. 如果你发现生成结果出现重复, 那么开启此项有助于避免生成重复"
 }
@@ -35,6 +35,7 @@ import { ResetConfigsButton } from '../components/ResetConfigsButton';
 import { useMediaQuery } from 'usehooks-ts';
 import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs';
 import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model';
+import { defaultPenaltyDecay } from './defaultConfigs';

 const ConfigSelector: FC<{
   selectedIndex: number,
@@ -66,14 +67,17 @@ const Configs: FC = observer(() => {
   const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex);
   const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]);
   const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false);
-  const advancedHeaderRef = useRef<HTMLDivElement>(null);
+  const advancedHeaderRef1 = useRef<HTMLDivElement>(null);
+  const advancedHeaderRef2 = useRef<HTMLDivElement>(null);
   const mq = useMediaQuery('(min-width: 640px)');
   const navigate = useNavigate();
   const port = selectedConfig.apiParameters.apiPort;

   useEffect(() => {
-    if (advancedHeaderRef.current)
-      (advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0';
+    if (advancedHeaderRef1.current)
+      (advancedHeaderRef1.current.firstElementChild as HTMLElement).style.padding = '0';
+    if (advancedHeaderRef2.current)
+      (advancedHeaderRef2.current.firstElementChild as HTMLElement).style.padding = '0';
   }, []);

   const updateSelectedIndex = useCallback((newIndex: number) => {
@@ -113,7 +117,9 @@ const Configs: FC = observer(() => {
       temperature: selectedConfig.apiParameters.temperature,
       top_p: selectedConfig.apiParameters.topP,
       presence_penalty: selectedConfig.apiParameters.presencePenalty,
-      frequency_penalty: selectedConfig.apiParameters.frequencyPenalty
+      frequency_penalty: selectedConfig.apiParameters.frequencyPenalty,
+      penalty_decay: selectedConfig.apiParameters.penaltyDecay,
+      global_penalty: selectedConfig.apiParameters.globalPenalty
     });
     toast(t('Config Saved'), { autoClose: 300, type: 'success' });
   };
@@ -194,6 +200,16 @@ const Configs: FC = observer(() => {
                     });
                   }} />
               } />
+            <Accordion className="sm:col-span-2" collapsible
+              openItems={!commonStore.apiParamsCollapsed && 'advanced'}
+              onToggle={(e, data) => {
+                if (data.value === 'advanced')
+                  commonStore.setApiParamsCollapsed(!commonStore.apiParamsCollapsed);
+              }}>
+              <AccordionItem value="advanced">
+                <AccordionHeader ref={advancedHeaderRef1} size="small">{t('Advanced')}</AccordionHeader>
+                <AccordionPanel>
+                  <div className="grid grid-cols-1 sm:grid-cols-2 gap-2">
             <Labeled label={t('Presence Penalty') + ' *'}
               desc={t('Positive values penalize new tokens based on whether they appear in the text so far, increasing the model\'s likelihood to talk about new topics.')}
               content={
@@ -216,6 +232,35 @@ const Configs: FC = observer(() => {
                     });
                   }} />
               } />
+                    <Labeled
+                      label={t('Penalty Decay')
+                        + ((!selectedConfig.apiParameters.penaltyDecay || selectedConfig.apiParameters.penaltyDecay === defaultPenaltyDecay)
+                          ? ` (${t('Default')})` : '')
+                        + ' *'}
+                      desc={t('If you don\'t know what it is, keep it default.')}
+                      content={
+                        <ValuedSlider value={selectedConfig.apiParameters.penaltyDecay || defaultPenaltyDecay}
+                          min={0.99} max={0.999} step={0.001} toFixed={3} input
+                          onChange={(e, data) => {
+                            setSelectedConfigApiParams({
+                              penaltyDecay: data.value
+                            });
+                          }} />
+                      } />
+                    <Labeled label={t('Global Penalty') + ' *'}
+                      desc={t('When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.')}
+                      content={
+                        <Switch checked={selectedConfig.apiParameters.globalPenalty}
+                          onChange={(e, data) => {
+                            setSelectedConfigApiParams({
+                              globalPenalty: data.checked
+                            });
+                          }} />
+                      } />
+                  </div>
+                </AccordionPanel>
+              </AccordionItem>
+            </Accordion>
           </div>
         }
       />
@@ -410,7 +455,7 @@ const Configs: FC = observer(() => {
                  commonStore.setModelParamsCollapsed(!commonStore.modelParamsCollapsed);
              }}>
              <AccordionItem value="advanced">
-               <AccordionHeader ref={advancedHeaderRef} size="small">{t('Advanced')}</AccordionHeader>
+               <AccordionHeader ref={advancedHeaderRef2} size="small">{t('Advanced')}</AccordionHeader>
                <AccordionPanel>
                  <div className="flex flex-col">
                    <div className="flex grow">
@@ -127,6 +127,7 @@ class CommonStore {
   // configs
   currentModelConfigIndex: number = 0;
   modelConfigs: ModelConfig[] = [];
+  apiParamsCollapsed: boolean = true;
   modelParamsCollapsed: boolean = true;
   // models
   activeModelListTags: string[] = [];
@@ -324,6 +325,10 @@ class CommonStore {
     this.advancedCollapsed = value;
   }

+  setApiParamsCollapsed(value: boolean) {
+    this.apiParamsCollapsed = value;
+  }
+
   setModelParamsCollapsed(value: boolean) {
     this.modelParamsCollapsed = value;
   }
@@ -6,6 +6,7 @@ export type ApiParameters = {
   presencePenalty: number;
   frequencyPenalty: number;
   penaltyDecay?: number;
+  globalPenalty?: boolean;
 }
 export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom';
 export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';