From 4f14074a75b5a10e7415c5f88dfa05d4acd0c283 Mon Sep 17 00:00:00 2001 From: josc146 Date: Sat, 2 Mar 2024 17:50:41 +0800 Subject: [PATCH] expose global_penalty --- backend-python/utils/rwkv.py | 10 ++- frontend/src/_locales/ja/main.json | 4 +- frontend/src/_locales/zh-hans/main.json | 4 +- frontend/src/pages/Configs.tsx | 99 ++++++++++++++++++------- frontend/src/stores/commonStore.ts | 5 ++ frontend/src/types/configs.ts | 1 + 6 files changed, 92 insertions(+), 31 deletions(-) diff --git a/backend-python/utils/rwkv.py b/backend-python/utils/rwkv.py index 2fff4ee..8a0cd7b 100644 --- a/backend-python/utils/rwkv.py +++ b/backend-python/utils/rwkv.py @@ -40,6 +40,7 @@ class AbstractRWKV(ABC): self.penalty_alpha_presence = 0 self.penalty_alpha_frequency = 1 self.penalty_decay = 0.996 + self.global_penalty = False @abstractmethod def adjust_occurrence(self, occurrence: Dict, token: int): @@ -403,8 +404,8 @@ class TextRWKV(AbstractRWKV): + occurrence[n] * self.penalty_alpha_frequency ) - # comment the codes below to get the same generated results as the official RWKV Gradio - if i == 0: + # set global_penalty to False to get the same generated results as the official RWKV Gradio + if self.global_penalty and i == 0: for token in self.model_tokens: token = int(token) if token not in self.AVOID_PENALTY_TOKENS: @@ -673,6 +674,7 @@ class ModelConfigBody(BaseModel): frequency_penalty: float = Field(default=None, ge=-2, le=2) penalty_decay: float = Field(default=None, ge=0.99, le=0.999) top_k: int = Field(default=None, ge=0, le=25) + global_penalty: bool = Field(default=None) model_config = { "json_schema_extra": { @@ -683,6 +685,7 @@ class ModelConfigBody(BaseModel): "presence_penalty": 0, "frequency_penalty": 1, "penalty_decay": 0.996, + "global_penalty": False, } } } @@ -706,6 +709,8 @@ def set_rwkv_config(model: AbstractRWKV, body: ModelConfigBody): model.penalty_decay = body.penalty_decay if body.top_k is not None: model.top_k = body.top_k + if body.global_penalty is not None: + model.global_penalty = body.global_penalty def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody: @@ -717,4 +722,5 @@ def get_rwkv_config(model: AbstractRWKV) -> ModelConfigBody: frequency_penalty=model.penalty_alpha_frequency, penalty_decay=model.penalty_decay, top_k=model.top_k, + global_penalty=model.global_penalty, ) diff --git a/frontend/src/_locales/ja/main.json b/frontend/src/_locales/ja/main.json index 419c625..cb11537 100644 --- a/frontend/src/_locales/ja/main.json +++ b/frontend/src/_locales/ja/main.json @@ -345,5 +345,7 @@ "Quantized Layers": "量子化されたレイヤー", "Number of the neural network layers quantized with current precision, the more you quantize, the lower the VRAM usage, but the quality correspondingly decreases.": "現在の精度で量子化されたニューラルネットワークのレイヤーの数、量子化するほどVRAMの使用量が低くなりますが、品質も相応に低下します。", "Parallel Token Chunk Size": "並列トークンチャンクサイズ", - "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一度に並列で処理される最大トークン数。高性能なGPUの場合、64または128になります(高速)。" + "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一度に並列で処理される最大トークン数。高性能なGPUの場合、64または128になります(高速)。", + "Global Penalty": "グローバルペナルティ", + "When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.": "レスポンスを生成する際、提出されたプロンプトをペナルティ要因として含めるかどうか。これをオフにすると、公式RWKV Gradioと同じ生成結果を得ることができます。生成された結果に重複がある場合、これをオンにすることで重複の生成を回避するのに役立ちます。" } \ No newline at end of file diff --git a/frontend/src/_locales/zh-hans/main.json b/frontend/src/_locales/zh-hans/main.json index 32d7929..9329348 100644 --- a/frontend/src/_locales/zh-hans/main.json +++ b/frontend/src/_locales/zh-hans/main.json @@ -345,5 +345,7 @@ "Quantized Layers": "量化层数", "Number of the neural network layers quantized with current precision, the more you quantize, the lower the VRAM usage, but the quality correspondingly decreases.": "神经网络以当前精度量化的层数, 量化越多, 占用显存越低, 但质量相应下降", "Parallel Token Chunk Size": "并行Token块大小", - "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一次最多可以并行处理的token数量. 对于高端显卡, 这可以是64或128 (更快)" + "Maximum tokens to be processed in parallel at once. For high end GPUs, this could be 64 or 128 (faster).": "一次最多可以并行处理的token数量. 对于高端显卡, 这可以是64或128 (更快)", + "Global Penalty": "全局惩罚", + "When generating a response, whether to include the submitted prompt as a penalty factor. By turning this off, you will get the same generated results as official RWKV Gradio. If you find duplicate results in the generated results, turning this on can help avoid generating duplicates.": "生成响应时, 是否将提交的prompt也纳入到惩罚项. 关闭此项将得到与RWKV官方Gradio完全一致的生成结果. 如果你发现生成结果出现重复, 那么开启此项有助于避免生成重复" } \ No newline at end of file diff --git a/frontend/src/pages/Configs.tsx b/frontend/src/pages/Configs.tsx index 62098ba..f6e14df 100644 --- a/frontend/src/pages/Configs.tsx +++ b/frontend/src/pages/Configs.tsx @@ -35,6 +35,7 @@ import { ResetConfigsButton } from '../components/ResetConfigsButton'; import { useMediaQuery } from 'usehooks-ts'; import { ApiParameters, Device, ModelParameters, Precision } from '../types/configs'; import { convertModel, convertToGGML, convertToSt } from '../utils/convert-model'; +import { defaultPenaltyDecay } from './defaultConfigs'; const ConfigSelector: FC<{ selectedIndex: number, @@ -66,14 +67,17 @@ const Configs: FC = observer(() => { const [selectedIndex, setSelectedIndex] = React.useState(commonStore.currentModelConfigIndex); const [selectedConfig, setSelectedConfig] = React.useState(commonStore.modelConfigs[selectedIndex]); const [displayStrategyImg, setDisplayStrategyImg] = React.useState(false); - const advancedHeaderRef = useRef(null); + const advancedHeaderRef1 = useRef(null); + const advancedHeaderRef2 = useRef(null); const mq = useMediaQuery('(min-width: 640px)'); const navigate = useNavigate(); const port = selectedConfig.apiParameters.apiPort; useEffect(() => { - if (advancedHeaderRef.current) - (advancedHeaderRef.current.firstElementChild as HTMLElement).style.padding = '0'; + if (advancedHeaderRef1.current) + (advancedHeaderRef1.current.firstElementChild as HTMLElement).style.padding = '0'; + if (advancedHeaderRef2.current) + (advancedHeaderRef2.current.firstElementChild as HTMLElement).style.padding = '0'; }, []); const updateSelectedIndex = useCallback((newIndex: number) => { @@ -113,7 +117,9 @@ const Configs: FC = observer(() => { temperature: selectedConfig.apiParameters.temperature, top_p: selectedConfig.apiParameters.topP, presence_penalty: selectedConfig.apiParameters.presencePenalty, - frequency_penalty: selectedConfig.apiParameters.frequencyPenalty + frequency_penalty: selectedConfig.apiParameters.frequencyPenalty, + penalty_decay: selectedConfig.apiParameters.penaltyDecay, + global_penalty: selectedConfig.apiParameters.globalPenalty }); toast(t('Config Saved'), { autoClose: 300, type: 'success' }); }; @@ -194,28 +200,67 @@ const Configs: FC = observer(() => { }); }} /> } /> - { - setSelectedConfigApiParams({ - presencePenalty: data.value - }); - }} /> - } /> - { - setSelectedConfigApiParams({ - frequencyPenalty: data.value - }); - }} /> - } /> + { + if (data.value === 'advanced') + commonStore.setApiParamsCollapsed(!commonStore.apiParamsCollapsed); + }}> + + {t('Advanced')} + +
+ { + setSelectedConfigApiParams({ + presencePenalty: data.value + }); + }} /> + } /> + { + setSelectedConfigApiParams({ + frequencyPenalty: data.value + }); + }} /> + } /> + { + setSelectedConfigApiParams({ + penaltyDecay: data.value + }); + }} /> + } /> + { + setSelectedConfigApiParams({ + globalPenalty: data.checked + }); + }} /> + } /> +
+
+
+
} /> @@ -410,7 +455,7 @@ const Configs: FC = observer(() => { commonStore.setModelParamsCollapsed(!commonStore.modelParamsCollapsed); }}> - {t('Advanced')} + {t('Advanced')}
diff --git a/frontend/src/stores/commonStore.ts b/frontend/src/stores/commonStore.ts index 1027d7a..d7e6cf8 100644 --- a/frontend/src/stores/commonStore.ts +++ b/frontend/src/stores/commonStore.ts @@ -127,6 +127,7 @@ class CommonStore { // configs currentModelConfigIndex: number = 0; modelConfigs: ModelConfig[] = []; + apiParamsCollapsed: boolean = true; modelParamsCollapsed: boolean = true; // models activeModelListTags: string[] = []; @@ -324,6 +325,10 @@ class CommonStore { this.advancedCollapsed = value; } + setApiParamsCollapsed(value: boolean) { + this.apiParamsCollapsed = value; + } + setModelParamsCollapsed(value: boolean) { this.modelParamsCollapsed = value; } diff --git a/frontend/src/types/configs.ts b/frontend/src/types/configs.ts index 852e58a..4a6526d 100644 --- a/frontend/src/types/configs.ts +++ b/frontend/src/types/configs.ts @@ -6,6 +6,7 @@ export type ApiParameters = { presencePenalty: number; frequencyPenalty: number; penaltyDecay?: number; + globalPenalty?: boolean; } export type Device = 'CPU' | 'CPU (rwkv.cpp)' | 'CUDA' | 'CUDA-Beta' | 'WebGPU' | 'WebGPU (Python)' | 'MPS' | 'Custom'; export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1';