diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt index 49d2a73..83f5491 100644 --- a/backend-python/requirements.txt +++ b/backend-python/requirements.txt @@ -1,7 +1,7 @@ torch torchvision torchaudio -rwkv==0.8.22 +rwkv==0.8.25 langchain==0.0.322 fastapi==0.104.0 uvicorn==0.23.2 diff --git a/backend-python/requirements_without_cyac.txt b/backend-python/requirements_without_cyac.txt index f4a9f05..54f2794 100644 --- a/backend-python/requirements_without_cyac.txt +++ b/backend-python/requirements_without_cyac.txt @@ -1,7 +1,7 @@ torch torchvision torchaudio -rwkv==0.8.22 +rwkv==0.8.25 langchain==0.0.322 fastapi==0.104.0 uvicorn==0.23.2 diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py index dccfc08..9f495d5 100644 --- a/backend-python/rwkv_pip/model.py +++ b/backend-python/rwkv_pip/model.py @@ -552,7 +552,12 @@ class RWKV(MyModule): elif ".ln_x" in x: # need fp32 for group_norm w[x] = w[x].float() else: - if (len(w[x].shape) == 2) and ("emb" not in x): + if ( + (len(w[x].shape) == 2) + and ("emb" not in x) + and ("_w1" not in x) + and ("_w2" not in x) + ): if WTYPE != torch.uint8: w[x] = w[x].to(dtype=WTYPE) else: diff --git a/backend-python/rwkv_pip/utils.py b/backend-python/rwkv_pip/utils.py index 23bb8db..969b977 100644 --- a/backend-python/rwkv_pip/utils.py +++ b/backend-python/rwkv_pip/utils.py @@ -171,10 +171,17 @@ class PIPELINE: all_tokens += [token] for xxx in occurrence: occurrence[xxx] *= args.alpha_decay + + ttt = self.decode([token]) + www = 1 + if ttt in " \t0123456789": + www = 0 + # elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】': + # www = 0.5 if token not in occurrence: - occurrence[token] = 1 + occurrence[token] = www else: - occurrence[token] += 1 + occurrence[token] += www # print(occurrence) # debug # output