
Commit

fix caching
evanarlian committed Oct 20, 2022
1 parent d4689b9 commit 8d72eb6
Showing 1 changed file with 13 additions and 9 deletions.
model2.py (13 additions, 9 deletions)
@@ -88,14 +88,9 @@ def forward(
         # or from the encoder (this case is called cross attention)
         # we just need to know whether the cross attention (xa) exists or not
         if xa is None:
-            # from decoder
-            # TODO original code still appends to the cach but not used, why?
-            k = self.key(x)
-            v = self.value(x)
-        else:
-            # from encoder, need to manage cache
-            curr_k = self.key(xa)
-            curr_v = self.value(xa)
+            # this is from decoder, we need to keep appending to the cache
+            curr_k = self.key(x)
+            curr_v = self.value(x)
             if self.k_id in kv_cache and self.v_id in kv_cache:
                 k = torch.cat([kv_cache[self.k_id], curr_k], dim=1)
                 v = torch.cat([kv_cache[self.v_id], curr_v], dim=1)
@@ -104,6 +99,16 @@ def forward(
                 v = curr_v
             kv_cache[self.k_id] = k
             kv_cache[self.v_id] = v
+        else:
+            # this is from encoder and only needed to be computed ONCE per new encoded mel
+            if self.k_id in kv_cache and self.v_id in kv_cache:
+                k = kv_cache[self.k_id]
+                v = kv_cache[self.v_id]
+            else:
+                k = self.key(xa)
+                v = self.value(xa)
+                kv_cache[self.k_id] = k
+                kv_cache[self.v_id] = v
 
         wv = self.masked_qkv_attention(q, k, v, mask)
         return self.out(wv)
@@ -249,7 +254,6 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor]):
 
         x = self.ln(x)
         logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
-
         return logits
 
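For context, here is a minimal standalone sketch of the caching pattern this commit settles on. It is not part of the commit: cached_kv, the K_ID/V_ID constants, and the toy Linear projections are hypothetical stand-ins for the module's self.key/self.value and its per-layer self.k_id/self.v_id. The decoder self-attention branch appends one step of keys/values to the cache per call, while the cross-attention branch projects the encoded mel once and then reuses it.

import torch

K_ID, V_ID = 0, 1  # hypothetical cache keys; the real module uses per-layer self.k_id / self.v_id


def cached_kv(x, xa, kv_cache, key, value):
    """Return (k, v) following the same branching as the patched forward()."""
    if xa is None:
        # decoder self-attention: project the new token(s) and append to the cache
        curr_k, curr_v = key(x), value(x)
        if K_ID in kv_cache and V_ID in kv_cache:
            k = torch.cat([kv_cache[K_ID], curr_k], dim=1)
            v = torch.cat([kv_cache[V_ID], curr_v], dim=1)
        else:
            k, v = curr_k, curr_v
        kv_cache[K_ID], kv_cache[V_ID] = k, v
    else:
        # cross-attention: the encoder output is fixed, so compute once and reuse
        if K_ID in kv_cache and V_ID in kv_cache:
            k, v = kv_cache[K_ID], kv_cache[V_ID]
        else:
            k, v = key(xa), value(xa)
            kv_cache[K_ID], kv_cache[V_ID] = k, v
    return k, v


dim = 8
key, value = torch.nn.Linear(dim, dim), torch.nn.Linear(dim, dim)  # toy projections
self_cache, cross_cache = {}, {}
xa = torch.randn(1, 4, dim)  # stand-in for the encoded mel, 4 frames
for step in range(3):
    x = torch.randn(1, 1, dim)  # one new decoder token per step
    k_self, _ = cached_kv(x, None, self_cache, key, value)
    k_cross, _ = cached_kv(x, xa, cross_cache, key, value)
    print(step, k_self.shape[1], k_cross.shape[1])  # self grows 1, 2, 3; cross stays 4

Under these assumptions the self-attention cache grows by one position per generated token, while the cross-attention cache is filled only on the first decoding step and read back afterwards, which is what the new else branch in the diff ensures.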
