upgrade to rwkv 0.8.26 (state instruct align support)

This commit is contained in:
josc146 2024-04-30 22:24:22 +08:00
parent 70236df3d1
commit 38b33a7030
3 changed files with 34 additions and 18 deletions

View File

@ -1,7 +1,7 @@
torch
torchvision
torchaudio
rwkv==0.8.25
rwkv==0.8.26
langchain==0.0.322
fastapi==0.109.1
uvicorn==0.23.2

View File

@ -1,7 +1,7 @@
torch
torchvision
torchaudio
rwkv==0.8.25
rwkv==0.8.26
langchain==0.0.322
fastapi==0.109.1
uvicorn==0.23.2

View File

@ -488,14 +488,19 @@ class RWKV(MyModule):
print_need_newline = False
REAL_TIME_FIRST = False
args.time_state = False
for x in list(w.keys()):
if ".time_faaaa" in x:
REAL_TIME_FIRST = True
if ".time_state" in x:
args.time_state = True
if REAL_TIME_FIRST:
w = {
(
k.replace(".time_faaaa", ".time_first")
if ".time_faaaa" in k
else k: v
else k
): v
for k, v in w.items()
}
self.w = w
@ -631,7 +636,9 @@ class RWKV(MyModule):
torch.cuda.empty_cache()
shape = [i for i in w[x].shape if i != 1]
if len(shape) > 1:
if len(shape) > 2:
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
elif len(shape) > 1:
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} "
else:
shape = f" {str(shape[0]).rjust(5)} "
@ -2108,6 +2115,15 @@ class RWKV(MyModule):
state[i * 3 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous()
if args.time_state:
state[i * 3 + 1] = (
w[f"blocks.{i}.att.time_state"]
.transpose(1, 2)
.to(dtype=torch.float, device=dev)
.requires_grad_(False)
.contiguous()
)
else:
state[i * 3 + 1] = torch.zeros(
(
args.n_head,