make gate and out trainable (834aea0f54)

This commit is contained in:
josc146 2024-03-24 15:47:17 +08:00
parent 2f777f1286
commit 0e4b6cbd15

View File

@ -29,7 +29,6 @@ LORA_CONFIG = {
class LoraLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()
@ -356,8 +355,8 @@ class RWKV_TimeMix_RWKV5(MyModule):
        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
-       self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
+       self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
-       self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
+       self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
@MyFunction @MyFunction
@ -465,8 +464,8 @@ class RWKV_Tmix_x060(MyModule):
        self.receptance = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.key = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.value = make_linear_att(args.n_embd, args.dim_att, bias=False)
-       self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
+       self.output = make_linear_att(args.dim_att, args.n_embd, bias=False)
-       self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
+       self.gate = make_linear_att(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(
            self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
        )