upgrade to rwkv 0.8.26 (state instruct align support)
This commit is contained in:
parent
70236df3d1
commit
38b33a7030
@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
rwkv==0.8.25
|
||||
rwkv==0.8.26
|
||||
langchain==0.0.322
|
||||
fastapi==0.109.1
|
||||
uvicorn==0.23.2
|
||||
|
@ -1,7 +1,7 @@
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
rwkv==0.8.25
|
||||
rwkv==0.8.26
|
||||
langchain==0.0.322
|
||||
fastapi==0.109.1
|
||||
uvicorn==0.23.2
|
||||
|
48
backend-python/rwkv_pip/model.py
vendored
48
backend-python/rwkv_pip/model.py
vendored
@ -488,14 +488,19 @@ class RWKV(MyModule):
|
||||
print_need_newline = False
|
||||
|
||||
REAL_TIME_FIRST = False
|
||||
args.time_state = False
|
||||
for x in list(w.keys()):
|
||||
if ".time_faaaa" in x:
|
||||
REAL_TIME_FIRST = True
|
||||
if ".time_state" in x:
|
||||
args.time_state = True
|
||||
if REAL_TIME_FIRST:
|
||||
w = {
|
||||
k.replace(".time_faaaa", ".time_first")
|
||||
if ".time_faaaa" in k
|
||||
else k: v
|
||||
(
|
||||
k.replace(".time_faaaa", ".time_first")
|
||||
if ".time_faaaa" in k
|
||||
else k
|
||||
): v
|
||||
for k, v in w.items()
|
||||
}
|
||||
self.w = w
|
||||
@ -631,10 +636,12 @@ class RWKV(MyModule):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
shape = [i for i in w[x].shape if i != 1]
|
||||
if len(shape) > 1:
|
||||
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
|
||||
if len(shape) > 2:
|
||||
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
|
||||
elif len(shape) > 1:
|
||||
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} "
|
||||
else:
|
||||
shape = f" {str(shape[0]).rjust(5)} "
|
||||
shape = f" {str(shape[0]).rjust(5)} "
|
||||
if layer_id == 0 or layer_id >= args.n_layer - 1:
|
||||
if print_need_newline:
|
||||
prxxx("\n", end="")
|
||||
@ -2108,16 +2115,25 @@ class RWKV(MyModule):
|
||||
state[i * 3 + 0] = torch.zeros(
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
state[i * 3 + 1] = torch.zeros(
|
||||
(
|
||||
args.n_head,
|
||||
args.n_att // args.n_head,
|
||||
args.n_att // args.n_head,
|
||||
),
|
||||
dtype=torch.float,
|
||||
requires_grad=False,
|
||||
device=dev,
|
||||
).contiguous()
|
||||
if args.time_state:
|
||||
state[i * 3 + 1] = (
|
||||
w[f"blocks.{i}.att.time_state"]
|
||||
.transpose(1, 2)
|
||||
.to(dtype=torch.float, device=dev)
|
||||
.requires_grad_(False)
|
||||
.contiguous()
|
||||
)
|
||||
else:
|
||||
state[i * 3 + 1] = torch.zeros(
|
||||
(
|
||||
args.n_head,
|
||||
args.n_att // args.n_head,
|
||||
args.n_att // args.n_head,
|
||||
),
|
||||
dtype=torch.float,
|
||||
requires_grad=False,
|
||||
device=dev,
|
||||
).contiguous()
|
||||
state[i * 3 + 2] = torch.zeros(
|
||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||
).contiguous()
|
||||
|
Loading…
Reference in New Issue
Block a user