upgrade to rwkv 0.8.26 (state instruct align support)
This commit is contained in:
parent
70236df3d1
commit
38b33a7030
@ -1,7 +1,7 @@
|
|||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
torchaudio
|
torchaudio
|
||||||
rwkv==0.8.25
|
rwkv==0.8.26
|
||||||
langchain==0.0.322
|
langchain==0.0.322
|
||||||
fastapi==0.109.1
|
fastapi==0.109.1
|
||||||
uvicorn==0.23.2
|
uvicorn==0.23.2
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
torchaudio
|
torchaudio
|
||||||
rwkv==0.8.25
|
rwkv==0.8.26
|
||||||
langchain==0.0.322
|
langchain==0.0.322
|
||||||
fastapi==0.109.1
|
fastapi==0.109.1
|
||||||
uvicorn==0.23.2
|
uvicorn==0.23.2
|
||||||
|
48
backend-python/rwkv_pip/model.py
vendored
48
backend-python/rwkv_pip/model.py
vendored
@ -488,14 +488,19 @@ class RWKV(MyModule):
|
|||||||
print_need_newline = False
|
print_need_newline = False
|
||||||
|
|
||||||
REAL_TIME_FIRST = False
|
REAL_TIME_FIRST = False
|
||||||
|
args.time_state = False
|
||||||
for x in list(w.keys()):
|
for x in list(w.keys()):
|
||||||
if ".time_faaaa" in x:
|
if ".time_faaaa" in x:
|
||||||
REAL_TIME_FIRST = True
|
REAL_TIME_FIRST = True
|
||||||
|
if ".time_state" in x:
|
||||||
|
args.time_state = True
|
||||||
if REAL_TIME_FIRST:
|
if REAL_TIME_FIRST:
|
||||||
w = {
|
w = {
|
||||||
k.replace(".time_faaaa", ".time_first")
|
(
|
||||||
if ".time_faaaa" in k
|
k.replace(".time_faaaa", ".time_first")
|
||||||
else k: v
|
if ".time_faaaa" in k
|
||||||
|
else k
|
||||||
|
): v
|
||||||
for k, v in w.items()
|
for k, v in w.items()
|
||||||
}
|
}
|
||||||
self.w = w
|
self.w = w
|
||||||
@ -631,10 +636,12 @@ class RWKV(MyModule):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
shape = [i for i in w[x].shape if i != 1]
|
shape = [i for i in w[x].shape if i != 1]
|
||||||
if len(shape) > 1:
|
if len(shape) > 2:
|
||||||
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}"
|
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
|
||||||
|
elif len(shape) > 1:
|
||||||
|
shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} "
|
||||||
else:
|
else:
|
||||||
shape = f" {str(shape[0]).rjust(5)} "
|
shape = f" {str(shape[0]).rjust(5)} "
|
||||||
if layer_id == 0 or layer_id >= args.n_layer - 1:
|
if layer_id == 0 or layer_id >= args.n_layer - 1:
|
||||||
if print_need_newline:
|
if print_need_newline:
|
||||||
prxxx("\n", end="")
|
prxxx("\n", end="")
|
||||||
@ -2108,16 +2115,25 @@ class RWKV(MyModule):
|
|||||||
state[i * 3 + 0] = torch.zeros(
|
state[i * 3 + 0] = torch.zeros(
|
||||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||||
).contiguous()
|
).contiguous()
|
||||||
state[i * 3 + 1] = torch.zeros(
|
if args.time_state:
|
||||||
(
|
state[i * 3 + 1] = (
|
||||||
args.n_head,
|
w[f"blocks.{i}.att.time_state"]
|
||||||
args.n_att // args.n_head,
|
.transpose(1, 2)
|
||||||
args.n_att // args.n_head,
|
.to(dtype=torch.float, device=dev)
|
||||||
),
|
.requires_grad_(False)
|
||||||
dtype=torch.float,
|
.contiguous()
|
||||||
requires_grad=False,
|
)
|
||||||
device=dev,
|
else:
|
||||||
).contiguous()
|
state[i * 3 + 1] = torch.zeros(
|
||||||
|
(
|
||||||
|
args.n_head,
|
||||||
|
args.n_att // args.n_head,
|
||||||
|
args.n_att // args.n_head,
|
||||||
|
),
|
||||||
|
dtype=torch.float,
|
||||||
|
requires_grad=False,
|
||||||
|
device=dev,
|
||||||
|
).contiguous()
|
||||||
state[i * 3 + 2] = torch.zeros(
|
state[i * 3 + 2] = torch.zeros(
|
||||||
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
args.n_embd, dtype=atype, requires_grad=False, device=dev
|
||||||
).contiguous()
|
).contiguous()
|
||||||
|
Loading…
Reference in New Issue
Block a user