upgrade to rwkv 0.8.26 (state instruct align support)

josc146 2024-04-30 22:24:22 +08:00
parent 70236df3d1
commit 38b33a7030
3 changed files with 34 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
 torch
 torchvision
 torchaudio
-rwkv==0.8.25
+rwkv==0.8.26
 langchain==0.0.322
 fastapi==0.109.1
 uvicorn==0.23.2

View File

@@ -1,7 +1,7 @@
 torch
 torchvision
 torchaudio
-rwkv==0.8.25
+rwkv==0.8.26
 langchain==0.0.322
 fastapi==0.109.1
 uvicorn==0.23.2
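
Both requirements files pin the same rwkv release, so the bump has to land in each. A minimal sketch (the check itself is not part of this commit) of failing fast when the installed package predates the pin:

# Minimal sketch, not from this commit: refuse to run against an rwkv
# install older than 0.8.26, since loading ".time_state" weights from
# state-tuned checkpoints is only supported from that release on.
from importlib.metadata import version

rwkv_version = version("rwkv")
if tuple(int(p) for p in rwkv_version.split(".")) < (0, 8, 26):
    raise RuntimeError(
        f"rwkv {rwkv_version} is too old: "
        "state instruct align support requires rwkv>=0.8.26"
    )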

View File

@@ -488,14 +488,19 @@ class RWKV(MyModule):
         print_need_newline = False
         REAL_TIME_FIRST = False
+        args.time_state = False
         for x in list(w.keys()):
             if ".time_faaaa" in x:
                 REAL_TIME_FIRST = True
+            if ".time_state" in x:
+                args.time_state = True
         if REAL_TIME_FIRST:
             w = {
-                k.replace(".time_faaaa", ".time_first")
-                if ".time_faaaa" in k
-                else k: v
+                (
+                    k.replace(".time_faaaa", ".time_first")
+                    if ".time_faaaa" in k
+                    else k
+                ): v
                 for k, v in w.items()
             }
         self.w = w
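
The new scan piggybacks on the existing key walk: a single pass over the checkpoint keys both detects state-tuned checkpoints (any ".time_state" key sets args.time_state) and, as before, renames legacy ".time_faaaa" keys to ".time_first"; the dict comprehension is only reparenthesized for formatting, its behavior is unchanged. A toy illustration with an invented three-key state dict:

# Toy state dict standing in for a real checkpoint; the key names follow
# the rwkv convention, the shapes are made up.
import torch

w = {
    "blocks.0.att.time_faaaa": torch.zeros(4),
    "blocks.0.att.time_state": torch.zeros(4, 8, 8),
    "blocks.0.ffn.key.weight": torch.zeros(8, 8),
}

time_state = any(".time_state" in k for k in w)  # True: state-tuned checkpoint
w = {
    (k.replace(".time_faaaa", ".time_first") if ".time_faaaa" in k else k): v
    for k, v in w.items()
}

print(time_state)  # True
print(sorted(w))   # time_faaaa renamed to time_first, time_state untouched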
@@ -631,8 +636,10 @@ class RWKV(MyModule):
                 torch.cuda.empty_cache()
             shape = [i for i in w[x].shape if i != 1]
-            if len(shape) > 1:
-                shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
+            if len(shape) > 2:
+                shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
+            elif len(shape) > 1:
+                shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} "
             else:
                 shape = f" {str(shape[0]).rjust(5)} "
             if layer_id == 0 or layer_id >= args.n_layer - 1:
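
time_state tensors are 3-dimensional (heads x head_size x head_size), so the per-weight progress printout gains a third shape column. A standalone sketch of the widened formatting (column widths as in the diff, sample shapes invented):

# Standalone reproduction of the shape column above.
def format_shape(dims):
    shape = [i for i in dims if i != 1]
    if len(shape) > 2:
        return f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
    elif len(shape) > 1:
        return f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} "
    else:
        return f" {str(shape[0]).rjust(5)} "

print(format_shape([2048, 2048]))  # 2-D weight: two columns
print(format_shape([32, 64, 64]))  # 3-D time_state: three columns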
@@ -2108,6 +2115,15 @@ class RWKV(MyModule):
                 state[i * 3 + 0] = torch.zeros(
                     args.n_embd, dtype=atype, requires_grad=False, device=dev
                 ).contiguous()
-                state[i * 3 + 1] = torch.zeros(
-                    (
-                        args.n_head,
+                if args.time_state:
+                    state[i * 3 + 1] = (
+                        w[f"blocks.{i}.att.time_state"]
+                        .transpose(1, 2)
+                        .to(dtype=torch.float, device=dev)
+                        .requires_grad_(False)
+                        .contiguous()
+                    )
+                else:
+                    state[i * 3 + 1] = torch.zeros(
+                        (
+                            args.n_head,
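
This is the piece that makes the state-instruct-align checkpoints take effect: when args.time_state is set, the per-layer WKV state is seeded from the checkpoint's blocks.{i}.att.time_state tensor (transposed into the runtime layout) instead of zeros. A self-contained sketch with invented dimensions; n_head, head_size, and the random tensor are stand-ins for values read from a real checkpoint:

# Self-contained sketch of the state seeding above.
import torch

n_head, head_size = 32, 64
dev = "cpu"
time_state_present = True  # mirrors args.time_state

# Stand-in for w[f"blocks.{i}.att.time_state"].
checkpoint_time_state = torch.randn(n_head, head_size, head_size)

if time_state_present:
    # Seed the WKV state from the tuned checkpoint tensor.
    wkv_state = (
        checkpoint_time_state.transpose(1, 2)
        .to(dtype=torch.float, device=dev)
        .requires_grad_(False)
        .contiguous()
    )
else:
    # Legacy path: zero-initialized WKV state.
    wkv_state = torch.zeros(
        (n_head, head_size, head_size),
        dtype=torch.float,
        requires_grad=False,
        device=dev,
    ).contiguous()

print(wkv_state.shape)  # torch.Size([32, 64, 64])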