
Commit

Revert "fix alibi inference shapes for cached layer_past" (EleutherAI…
ShivanshuPurohit committed Nov 2, 2021
1 parent c638323 commit 19b1683
Showing 1 changed file with 4 additions and 15 deletions.

megatron/model/positional_embeddings.py
@@ -102,25 +102,14 @@ def get_slopes_power_of_2(n):
 
     def forward(self, x):
         # [b, np, sq, sk]
-        seq_len_q = x.shape[-2]
-        seq_len_k = x.shape[-1]
-        if self.cached_seq_len != seq_len_k:
-            a = -torch.tril(torch.arange(seq_len_k).view(seq_len_k, 1).repeat(1, seq_len_k) + torch.arange(0, -seq_len_k, -1))
+        seq_len = x.shape[-1]
+        if self.cached_seq_len != seq_len:
+            a = -torch.tril(torch.arange(seq_len).view(seq_len, 1).repeat(1, seq_len) + torch.arange(0, -seq_len, -1))
             a = a.to(x.device).to(x.dtype)
             slopes = self.slopes.to(a.device).to(a.dtype)
             a = a * slopes.view(self.slopes.shape[0], 1, 1)
-            self.cached_seq_len = seq_len_k
+            self.cached_seq_len = seq_len
             self.cached_matrix = a
         else:
             a = self.cached_matrix
 
-        if seq_len_q != seq_len_k:
-            # In the train case x has dimensionality [b, np, sq, sk] with sq == sk:
-            # the number of query tokens equals the number of key tokens.
-            # At inference time with a cache in layer_past, sq != sk: sq contains only one token (the last one in the full sequence).
-            # In that case we use the appropriate row of the cached matrix.
-            # Since the cached matrix may already be larger from a past inference, the index comes from seq_len_k, not from the sq sequence.
-            assert seq_len_q == 1, "assumption sq == sk unless at inference time with cache in layer_past with sq == 1"
-            a = a[:, seq_len_k - 1, :].view(a.shape[0], 1, a.shape[2])  # seq_len_k - 1 is the last token index in the current inference batch
-
         return x + a
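For context: ALiBi adds a head-specific linear bias to the attention scores. During incremental decoding with layer_past, the scores x have shape [b, np, 1, sk] while the cached bias matrix is [np, sk, sk], so the branch deleted by this revert sliced out the bias row for the current (last) key position. Below is a minimal, self-contained sketch of that mechanic, not the repository's actual module: the AliBi class and alibi_slopes helper are simplified stand-ins, and the slope formula assumes the head count is a power of two.

import math
import torch

def alibi_slopes(n_heads):
    # Head-specific slopes from the ALiBi paper; assumes n_heads is a power of two.
    start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
    return torch.tensor([start * start ** i for i in range(n_heads)])

class AliBi(torch.nn.Module):
    def __init__(self, n_heads):
        super().__init__()
        self.slopes = alibi_slopes(n_heads)
        self.cached_seq_len = None
        self.cached_matrix = None

    def forward(self, x):
        # x: attention scores of shape [b, np, sq, sk]
        seq_len_q, seq_len_k = x.shape[-2], x.shape[-1]
        if self.cached_seq_len != seq_len_k:
            # bias[i, j] = -(i - j) for j <= i: 0 on the diagonal, -1 one key back, ...
            a = -torch.tril(
                torch.arange(seq_len_k).view(seq_len_k, 1).repeat(1, seq_len_k)
                + torch.arange(0, -seq_len_k, -1)
            )
            a = a.to(x.device).to(x.dtype)
            a = a * self.slopes.to(a.device).to(a.dtype).view(-1, 1, 1)  # [np, sk, sk]
            self.cached_seq_len = seq_len_k
            self.cached_matrix = a
        else:
            a = self.cached_matrix
        if seq_len_q != seq_len_k:
            # Incremental decoding with layer_past: only the bias row for the
            # current (last) key position applies to the single query token.
            assert seq_len_q == 1
            a = a[:, seq_len_k - 1, :].view(a.shape[0], 1, a.shape[2])
        return x + a

scores = torch.zeros(1, 8, 1, 5)   # one query token scored against 5 cached keys
print(AliBi(8)(scores).shape)      # torch.Size([1, 8, 1, 5])

After the revert, x + a again relies on sq == sk: for a single-query step, adding the [np, sk, sk] bias to [b, np, 1, sk] scores broadcasts to [b, np, sk, sk] instead of matching the scores' shape.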
