upgrade to rwkv 0.8.25

This commit is contained in:
josc146 2024-02-21 23:50:05 +08:00
parent 85b10993ec
commit fafd9f7f6e
4 changed files with 17 additions and 5 deletions

View File

@@ -1,7 +1,7 @@
 torch
 torchvision
 torchaudio
-rwkv==0.8.22
+rwkv==0.8.25
 langchain==0.0.322
 fastapi==0.104.0
 uvicorn==0.23.2

View File

@@ -1,7 +1,7 @@
 torch
 torchvision
 torchaudio
-rwkv==0.8.22
+rwkv==0.8.25
 langchain==0.0.322
 fastapi==0.104.0
 uvicorn==0.23.2

View File

@@ -552,7 +552,12 @@ class RWKV(MyModule):
 elif ".ln_x" in x: # need fp32 for group_norm
     w[x] = w[x].float()
 else:
-    if (len(w[x].shape) == 2) and ("emb" not in x):
+    if (
+        (len(w[x].shape) == 2)
+        and ("emb" not in x)
+        and ("_w1" not in x)
+        and ("_w2" not in x)
+    ):
         if WTYPE != torch.uint8:
             w[x] = w[x].to(dtype=WTYPE)
         else:

View File

@@ -171,10 +171,17 @@ class PIPELINE:
 all_tokens += [token]
 for xxx in occurrence:
     occurrence[xxx] *= args.alpha_decay
+ttt = self.decode([token])
+www = 1
+if ttt in " \t0123456789":
+    www = 0
+# elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
+#     www = 0.5
 if token not in occurrence:
-    occurrence[token] = 1
+    occurrence[token] = www
 else:
-    occurrence[token] += 1
+    occurrence[token] += www
 # print(occurrence) # debug
 # output