upgrade to rwkv 0.8.25
This commit is contained in:
7
backend-python/rwkv_pip/model.py
vendored
7
backend-python/rwkv_pip/model.py
vendored
@@ -552,7 +552,12 @@ class RWKV(MyModule):
|
||||
elif ".ln_x" in x: # need fp32 for group_norm
|
||||
w[x] = w[x].float()
|
||||
else:
|
||||
if (len(w[x].shape) == 2) and ("emb" not in x):
|
||||
if (
|
||||
(len(w[x].shape) == 2)
|
||||
and ("emb" not in x)
|
||||
and ("_w1" not in x)
|
||||
and ("_w2" not in x)
|
||||
):
|
||||
if WTYPE != torch.uint8:
|
||||
w[x] = w[x].to(dtype=WTYPE)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user