upgrade to rwkv 0.8.26 (state instruct align support)

This commit is contained in:
josc146 2024-04-30 22:24:22 +08:00
parent 70236df3d1
commit 38b33a7030
3 changed files with 34 additions and 18 deletions

View File

@ -1,7 +1,7 @@
torch torch
torchvision torchvision
torchaudio torchaudio
rwkv==0.8.25 rwkv==0.8.26
langchain==0.0.322 langchain==0.0.322
fastapi==0.109.1 fastapi==0.109.1
uvicorn==0.23.2 uvicorn==0.23.2

View File

@ -1,7 +1,7 @@
torch torch
torchvision torchvision
torchaudio torchaudio
rwkv==0.8.25 rwkv==0.8.26
langchain==0.0.322 langchain==0.0.322
fastapi==0.109.1 fastapi==0.109.1
uvicorn==0.23.2 uvicorn==0.23.2

View File

@ -488,14 +488,19 @@ class RWKV(MyModule):
print_need_newline = False print_need_newline = False
REAL_TIME_FIRST = False REAL_TIME_FIRST = False
args.time_state = False
for x in list(w.keys()): for x in list(w.keys()):
if ".time_faaaa" in x: if ".time_faaaa" in x:
REAL_TIME_FIRST = True REAL_TIME_FIRST = True
if ".time_state" in x:
args.time_state = True
if REAL_TIME_FIRST: if REAL_TIME_FIRST:
w = { w = {
k.replace(".time_faaaa", ".time_first") (
if ".time_faaaa" in k k.replace(".time_faaaa", ".time_first")
else k: v if ".time_faaaa" in k
else k
): v
for k, v in w.items() for k, v in w.items()
} }
self.w = w self.w = w
@ -631,10 +636,12 @@ class RWKV(MyModule):
torch.cuda.empty_cache() torch.cuda.empty_cache()
shape = [i for i in w[x].shape if i != 1] shape = [i for i in w[x].shape if i != 1]
if len(shape) > 1: if len(shape) > 2:
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
elif len(shape) > 1:
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} "
else: else:
shape = f" {str(shape[0]).rjust(5)} " shape = f" {str(shape[0]).rjust(5)} "
if layer_id == 0 or layer_id >= args.n_layer - 1: if layer_id == 0 or layer_id >= args.n_layer - 1:
if print_need_newline: if print_need_newline:
prxxx("\n", end="") prxxx("\n", end="")
@ -2108,16 +2115,25 @@ class RWKV(MyModule):
state[i * 3 + 0] = torch.zeros( state[i * 3 + 0] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous() ).contiguous()
state[i * 3 + 1] = torch.zeros( if args.time_state:
( state[i * 3 + 1] = (
args.n_head, w[f"blocks.{i}.att.time_state"]
args.n_att // args.n_head, .transpose(1, 2)
args.n_att // args.n_head, .to(dtype=torch.float, device=dev)
), .requires_grad_(False)
dtype=torch.float, .contiguous()
requires_grad=False, )
device=dev, else:
).contiguous() state[i * 3 + 1] = torch.zeros(
(
args.n_head,
args.n_att // args.n_head,
args.n_att // args.n_head,
),
dtype=torch.float,
requires_grad=False,
device=dev,
).contiguous()
state[i * 3 + 2] = torch.zeros( state[i * 3 + 2] = torch.zeros(
args.n_embd, dtype=atype, requires_grad=False, device=dev args.n_embd, dtype=atype, requires_grad=False, device=dev
).contiguous() ).contiguous()