upgrade to rwkv 0.8.25
This commit is contained in:
parent
85b10993ec
commit
fafd9f7f6e
@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
rwkv==0.8.22
|
||||
rwkv==0.8.25
|
||||
langchain==0.0.322
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
|
@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
rwkv==0.8.22
|
||||
rwkv==0.8.25
|
||||
langchain==0.0.322
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
|
7
backend-python/rwkv_pip/model.py
vendored
7
backend-python/rwkv_pip/model.py
vendored
@ -552,7 +552,12 @@ class RWKV(MyModule):
|
||||
elif ".ln_x" in x: # need fp32 for group_norm
|
||||
w[x] = w[x].float()
|
||||
else:
|
||||
if (len(w[x].shape) == 2) and ("emb" not in x):
|
||||
if (
|
||||
(len(w[x].shape) == 2)
|
||||
and ("emb" not in x)
|
||||
and ("_w1" not in x)
|
||||
and ("_w2" not in x)
|
||||
):
|
||||
if WTYPE != torch.uint8:
|
||||
w[x] = w[x].to(dtype=WTYPE)
|
||||
else:
|
||||
|
11
backend-python/rwkv_pip/utils.py
vendored
11
backend-python/rwkv_pip/utils.py
vendored
@ -171,10 +171,17 @@ class PIPELINE:
|
||||
all_tokens += [token]
|
||||
for xxx in occurrence:
|
||||
occurrence[xxx] *= args.alpha_decay
|
||||
|
||||
ttt = self.decode([token])
|
||||
www = 1
|
||||
if ttt in " \t0123456789":
|
||||
www = 0
|
||||
# elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
|
||||
# www = 0.5
|
||||
if token not in occurrence:
|
||||
occurrence[token] = 1
|
||||
occurrence[token] = www
|
||||
else:
|
||||
occurrence[token] += 1
|
||||
occurrence[token] += www
|
||||
# print(occurrence) # debug
|
||||
|
||||
# output
|
||||
|
Loading…
Reference in New Issue
Block a user