Fix attention caching to make it actually work (openai#370)
vickianand committed Oct 19, 2022
1 parent 7f3e408 commit 9f70a35
Showing 1 changed file with 3 additions and 3 deletions.
whisper/model.py: 3 additions & 3 deletions
@@ -72,15 +72,15 @@ def forward(
     ):
         q = self.query(x)

-        if kv_cache is None or xa is None:
+        if kv_cache is None or xa is None or self.key not in kv_cache:
             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
             # otherwise, perform key/value projections for self- or cross-attention as usual.
             k = self.key(x if xa is None else xa)
             v = self.value(x if xa is None else xa)
         else:
             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
-            k = kv_cache.get(self.key, self.key(xa))
-            v = kv_cache.get(self.value, self.value(xa))
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]

         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)
