From 0e4b6cbd15f708f2841d1ee0fe95c00a55273ac0 Mon Sep 17 00:00:00 2001 From: josc146 Date: Sun, 24 Mar 2024 15:47:17 +0800 Subject: [PATCH] make gate and out trainable (https://github.com/JL-er/RWKV-LORA/commit/834aea0f543b84eedb7cdd222672c06ca37aeacd) --- finetune/lora/v6/src/model.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/finetune/lora/v6/src/model.py b/finetune/lora/v6/src/model.py index a95b16a..dbeb0f8 100644 --- a/finetune/lora/v6/src/model.py +++ b/finetune/lora/v6/src/model.py @@ -29,7 +29,6 @@ LORA_CONFIG = { class LoraLinear(nn.Module): - def __init__(self, in_features: int, out_features: int, bias: bool): super().__init__() @@ -356,8 +355,8 @@ class RWKV_TimeMix_RWKV5(MyModule): self.key = make_linear_att(args.n_embd, args.dim_att, bias=False) self.value = make_linear_att(args.n_embd, args.dim_att, bias=False) - self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) - self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.output = make_linear_att(args.dim_att, args.n_embd, bias=False) + self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False) self.ln_x = nn.GroupNorm(self.n_head, args.dim_att) @MyFunction @@ -465,8 +464,8 @@ class RWKV_Tmix_x060(MyModule): self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False) self.key = make_linear_att(args.n_embd, args.dim_att, bias=False) self.value = make_linear_att(args.n_embd, args.dim_att, bias=False) - self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) - self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False) + self.output = make_linear_att(args.dim_att, args.n_embd, bias=False) + self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False) self.ln_x = nn.GroupNorm( self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2) )