Commit

better tinyAtt
www committed Nov 28, 2022
1 parent de8bae7 commit a268cd2
Showing 2 changed files with 11 additions and 10 deletions.
20 changes: 11 additions & 9 deletions RWKV-v4neo/src/model.py
@@ -234,10 +234,11 @@ def __init__(self, args, layer_id):
         self.ffn = RWKV_ChannelMix(args, layer_id)
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            self.head_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
-            self.head_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
-            self.register_buffer("head_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
+            self.tiny_ln = nn.LayerNorm(args.n_embd)
+            self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False)
+            self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False)
+            self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))
 
     def forward(self, x, x_emb=None):
         args = self.args
@@ -255,11 +256,12 @@ def forward(self, x, x_emb=None):
         x = x + self.ffn(self.ln2(x))
 
         if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer:
-            q = self.head_q(x)[:, :T, :]
-            k = self.head_k(x)[:, :T, :]
-            c = (q @ k.transpose(-2, -1)) * (1.0 / args.tiny_att_downscale)
-            c = c.masked_fill(self.head_mask[:T, :T] == 0, 0)
-            x = x + c @ self.head_v(x_emb)
+            xx = self.tiny_ln(x)
+            q = self.tiny_q(xx)[:, :T, :]
+            k = self.tiny_k(xx)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (args.tiny_att_dim ** (-0.5))
+            c = c.masked_fill(self.tiny_mask[:T, :T] == 0, 0)
+            x = x + c @ self.tiny_v(x_emb)
         return x


1 change: 0 additions & 1 deletion RWKV-v4neo/train.py
@@ -70,7 +70,6 @@
     parser.add_argument("--head_qk", default=0, type=int)  # my headQK trick
     parser.add_argument("--tiny_att_dim", default=0, type=int)  # tiny attention dim
     parser.add_argument("--tiny_att_layer", default=-999, type=int)  # tiny attention @ which layer
-    parser.add_argument("--tiny_att_downscale", default=0, type=float)
 
     parser.add_argument("--lr_init", default=6e-4, type=float)  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
     parser.add_argument("--lr_final", default=1e-5, type=float)
