upgrade to rwkv 0.8.25
This commit is contained in:
parent
85b10993ec
commit
fafd9f7f6e
@@ -1,7 +1,7 @@
|
|||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
torchaudio
|
torchaudio
|
||||||
rwkv==0.8.22
|
rwkv==0.8.25
|
||||||
langchain==0.0.322
|
langchain==0.0.322
|
||||||
fastapi==0.104.0
|
fastapi==0.104.0
|
||||||
uvicorn==0.23.2
|
uvicorn==0.23.2
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
torchaudio
|
torchaudio
|
||||||
rwkv==0.8.22
|
rwkv==0.8.25
|
||||||
langchain==0.0.322
|
langchain==0.0.322
|
||||||
fastapi==0.104.0
|
fastapi==0.104.0
|
||||||
uvicorn==0.23.2
|
uvicorn==0.23.2
|
||||||
|
7
backend-python/rwkv_pip/model.py
vendored
7
backend-python/rwkv_pip/model.py
vendored
@@ -552,7 +552,12 @@ class RWKV(MyModule):
|
|||||||
elif ".ln_x" in x: # need fp32 for group_norm
|
elif ".ln_x" in x: # need fp32 for group_norm
|
||||||
w[x] = w[x].float()
|
w[x] = w[x].float()
|
||||||
else:
|
else:
|
||||||
if (len(w[x].shape) == 2) and ("emb" not in x):
|
if (
|
||||||
|
(len(w[x].shape) == 2)
|
||||||
|
and ("emb" not in x)
|
||||||
|
and ("_w1" not in x)
|
||||||
|
and ("_w2" not in x)
|
||||||
|
):
|
||||||
if WTYPE != torch.uint8:
|
if WTYPE != torch.uint8:
|
||||||
w[x] = w[x].to(dtype=WTYPE)
|
w[x] = w[x].to(dtype=WTYPE)
|
||||||
else:
|
else:
|
||||||
|
11
backend-python/rwkv_pip/utils.py
vendored
11
backend-python/rwkv_pip/utils.py
vendored
@@ -171,10 +171,17 @@ class PIPELINE:
|
|||||||
all_tokens += [token]
|
all_tokens += [token]
|
||||||
for xxx in occurrence:
|
for xxx in occurrence:
|
||||||
occurrence[xxx] *= args.alpha_decay
|
occurrence[xxx] *= args.alpha_decay
|
||||||
|
|
||||||
|
ttt = self.decode([token])
|
||||||
|
www = 1
|
||||||
|
if ttt in " \t0123456789":
|
||||||
|
www = 0
|
||||||
|
# elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
|
||||||
|
# www = 0.5
|
||||||
if token not in occurrence:
|
if token not in occurrence:
|
||||||
occurrence[token] = 1
|
occurrence[token] = www
|
||||||
else:
|
else:
|
||||||
occurrence[token] += 1
|
occurrence[token] += www
|
||||||
# print(occurrence) # debug
|
# print(occurrence) # debug
|
||||||
|
|
||||||
# output
|
# output
|
||||||
|
Loading…
Reference in New Issue
Block a user