upgrade to rwkv 0.8.25

This commit is contained in:
josc146
2024-02-21 23:50:05 +08:00
parent 85b10993ec
commit fafd9f7f6e
4 changed files with 17 additions and 5 deletions

View File

@@ -552,7 +552,12 @@ class RWKV(MyModule):
elif ".ln_x" in x: # need fp32 for group_norm
w[x] = w[x].float()
else:
if (len(w[x].shape) == 2) and ("emb" not in x):
if (
(len(w[x].shape) == 2)
and ("emb" not in x)
and ("_w1" not in x)
and ("_w2" not in x)
):
if WTYPE != torch.uint8:
w[x] = w[x].to(dtype=WTYPE)
else: