From fafd9f7f6e322515123eea1a1093032015d7eb61 Mon Sep 17 00:00:00 2001 From: josc146 Date: Wed, 21 Feb 2024 23:50:05 +0800 Subject: [PATCH] upgrade to rwkv 0.8.25 --- backend-python/requirements.txt | 2 +- backend-python/requirements_without_cyac.txt | 2 +- backend-python/rwkv_pip/model.py | 7 ++++++- backend-python/rwkv_pip/utils.py | 11 +++++++++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt index 49d2a73..83f5491 100644 --- a/backend-python/requirements.txt +++ b/backend-python/requirements.txt @@ -1,7 +1,7 @@ torch torchvision torchaudio -rwkv==0.8.22 +rwkv==0.8.25 langchain==0.0.322 fastapi==0.104.0 uvicorn==0.23.2 diff --git a/backend-python/requirements_without_cyac.txt b/backend-python/requirements_without_cyac.txt index f4a9f05..54f2794 100644 --- a/backend-python/requirements_without_cyac.txt +++ b/backend-python/requirements_without_cyac.txt @@ -1,7 +1,7 @@ torch torchvision torchaudio -rwkv==0.8.22 +rwkv==0.8.25 langchain==0.0.322 fastapi==0.104.0 uvicorn==0.23.2 diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py index dccfc08..9f495d5 100644 --- a/backend-python/rwkv_pip/model.py +++ b/backend-python/rwkv_pip/model.py @@ -552,7 +552,12 @@ class RWKV(MyModule): elif ".ln_x" in x: # need fp32 for group_norm w[x] = w[x].float() else: - if (len(w[x].shape) == 2) and ("emb" not in x): + if ( + (len(w[x].shape) == 2) + and ("emb" not in x) + and ("_w1" not in x) + and ("_w2" not in x) + ): if WTYPE != torch.uint8: w[x] = w[x].to(dtype=WTYPE) else: diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index 23bb8db..969b977 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -171,10 +171,17 @@ class PIPELINE: all_tokens += [token] for xxx in occurrence: occurrence[xxx] *= args.alpha_decay + + ttt = self.decode([token]) + www = 1 + if ttt in " \t0123456789": + www = 0 + # elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】': + # www = 0.5 if token not in occurrence: - occurrence[token] = 1 + occurrence[token] = www else: - occurrence[token] += 1 + occurrence[token] += www # print(occurrence) # debug # output