upgrade to rwkv 0.8.25

This commit is contained in:
josc146 2024-02-21 23:50:05 +08:00
parent 85b10993ec
commit fafd9f7f6e
4 changed files with 17 additions and 5 deletions

View File

@ -1,7 +1,7 @@
torch
torchvision
torchaudio
rwkv==0.8.22
rwkv==0.8.25
langchain==0.0.322
fastapi==0.104.0
uvicorn==0.23.2

View File

@ -1,7 +1,7 @@
torch
torchvision
torchaudio
rwkv==0.8.22
rwkv==0.8.25
langchain==0.0.322
fastapi==0.104.0
uvicorn==0.23.2

View File

@ -552,7 +552,12 @@ class RWKV(MyModule):
elif ".ln_x" in x: # need fp32 for group_norm
w[x] = w[x].float()
else:
if (len(w[x].shape) == 2) and ("emb" not in x):
if (
(len(w[x].shape) == 2)
and ("emb" not in x)
and ("_w1" not in x)
and ("_w2" not in x)
):
if WTYPE != torch.uint8:
w[x] = w[x].to(dtype=WTYPE)
else:

View File

@ -171,10 +171,17 @@ class PIPELINE:
all_tokens += [token]
for xxx in occurrence:
occurrence[xxx] *= args.alpha_decay
ttt = self.decode([token])
www = 1
if ttt in " \t0123456789":
www = 0
# elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
# www = 0.5
if token not in occurrence:
occurrence[token] = 1
occurrence[token] = www
else:
occurrence[token] += 1
occurrence[token] += www
# print(occurrence) # debug
# output