diff --git a/backend-python/requirements.txt b/backend-python/requirements.txt index 5acb20c..1d19926 100644 --- a/backend-python/requirements.txt +++ b/backend-python/requirements.txt @@ -1,7 +1,7 @@ torch torchvision torchaudio -rwkv==0.8.25 +rwkv==0.8.26 langchain==0.0.322 fastapi==0.109.1 uvicorn==0.23.2 diff --git a/backend-python/requirements_without_cyac.txt b/backend-python/requirements_without_cyac.txt index c9423fe..b23a384 100644 --- a/backend-python/requirements_without_cyac.txt +++ b/backend-python/requirements_without_cyac.txt @@ -1,7 +1,7 @@ torch torchvision torchaudio -rwkv==0.8.25 +rwkv==0.8.26 langchain==0.0.322 fastapi==0.109.1 uvicorn==0.23.2 diff --git a/backend-python/rwkv_pip/model.py b/backend-python/rwkv_pip/model.py index 9f495d5..570f77c 100644 --- a/backend-python/rwkv_pip/model.py +++ b/backend-python/rwkv_pip/model.py @@ -488,14 +488,19 @@ class RWKV(MyModule): print_need_newline = False REAL_TIME_FIRST = False + args.time_state = False for x in list(w.keys()): if ".time_faaaa" in x: REAL_TIME_FIRST = True + if ".time_state" in x: + args.time_state = True if REAL_TIME_FIRST: w = { - k.replace(".time_faaaa", ".time_first") - if ".time_faaaa" in k - else k: v + ( + k.replace(".time_faaaa", ".time_first") + if ".time_faaaa" in k + else k + ): v for k, v in w.items() } self.w = w @@ -631,10 +636,12 @@ class RWKV(MyModule): torch.cuda.empty_cache() shape = [i for i in w[x].shape if i != 1] - if len(shape) > 1: - shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" + if len(shape) > 2: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}" + elif len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} " else: - shape = f" {str(shape[0]).rjust(5)} " + shape = f" {str(shape[0]).rjust(5)} " if layer_id == 0 or layer_id >= args.n_layer - 1: if print_need_newline: prxxx("\n", end="") @@ -2108,16 +2115,25 @@ class RWKV(MyModule): state[i * 3 + 0] = torch.zeros( args.n_embd, dtype=atype, requires_grad=False, device=dev ).contiguous() - state[i * 3 + 1] = torch.zeros( - ( - args.n_head, - args.n_att // args.n_head, - args.n_att // args.n_head, - ), - dtype=torch.float, - requires_grad=False, - device=dev, - ).contiguous() + if args.time_state: + state[i * 3 + 1] = ( + w[f"blocks.{i}.att.time_state"] + .transpose(1, 2) + .to(dtype=torch.float, device=dev) + .requires_grad_(False) + .contiguous() + ) + else: + state[i * 3 + 1] = torch.zeros( + ( + args.n_head, + args.n_att // args.n_head, + args.n_att // args.n_head, + ), + dtype=torch.float, + requires_grad=False, + device=dev, + ).contiguous() state[i * 3 + 2] = torch.zeros( args.n_embd, dtype=atype, requires_grad=False, device=dev ).contiguous()