From 8d72eb67435882fc2f4822bb82f56d4bad6641a6 Mon Sep 17 00:00:00 2001
From: Evan Arlian
Date: Thu, 20 Oct 2022 11:32:48 +0700
Subject: [PATCH] fix caching

---
 model2.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/model2.py b/model2.py
index fb87aadc..1c1f2127 100644
--- a/model2.py
+++ b/model2.py
@@ -88,14 +88,9 @@ def forward(
         # or from the encoder (this case is called cross attention)
         # we just need to know whether the cross attention (xa) exists or not
         if xa is None:
-            # from decoder
-            # TODO original code still appends to the cach but not used, why?
-            k = self.key(x)
-            v = self.value(x)
-        else:
-            # from encoder, need to manage cache
-            curr_k = self.key(xa)
-            curr_v = self.value(xa)
+            # from the decoder: we need to keep appending to the cache
+            curr_k = self.key(x)
+            curr_v = self.value(x)
             if self.k_id in kv_cache and self.v_id in kv_cache:
                 k = torch.cat([kv_cache[self.k_id], curr_k], dim=1)
                 v = torch.cat([kv_cache[self.v_id], curr_v], dim=1)
@@ -104,6 +99,16 @@
                 v = curr_v
             kv_cache[self.k_id] = k
             kv_cache[self.v_id] = v
+        else:
+            # from the encoder: only needs to be computed ONCE per new encoded mel
+            if self.k_id in kv_cache and self.v_id in kv_cache:
+                k = kv_cache[self.k_id]
+                v = kv_cache[self.v_id]
+            else:
+                k = self.key(xa)
+                v = self.value(xa)
+                kv_cache[self.k_id] = k
+                kv_cache[self.v_id] = v
 
         wv = self.masked_qkv_attention(q, k, v, mask)
         return self.out(wv)
@@ -249,7 +254,6 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor]):
 
         x = self.ln(x)
 
         logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
-
         return logits
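
For reviewers, below is a minimal, self-contained sketch of the caching behavior this patch establishes. The branch structure and the cache-key scheme (k_id/v_id indexing a shared kv_cache dict) mirror the diff; the CachedKV class name, the id-allocation trick, the tensor sizes, and the toy decode loop are illustrative assumptions, not code from model2.py (query/out projections and the attention itself are omitted to keep the focus on k/v caching).

from __future__ import annotations

import torch
from torch import Tensor, nn


class CachedKV(nn.Module):
    # assumption: a shared counter hands out a unique cache key per projection
    _ids = iter(range(10_000))

    def __init__(self, n_state: int):
        super().__init__()
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.k_id = next(self._ids)
        self.v_id = next(self._ids)

    def forward(self, x: Tensor, xa: Tensor | None, kv_cache: dict[int, Tensor]):
        if xa is None:
            # decoder self-attention: append this step's keys/values to the cache
            curr_k = self.key(x)
            curr_v = self.value(x)
            if self.k_id in kv_cache and self.v_id in kv_cache:
                k = torch.cat([kv_cache[self.k_id], curr_k], dim=1)
                v = torch.cat([kv_cache[self.v_id], curr_v], dim=1)
            else:
                k = curr_k
                v = curr_v
            kv_cache[self.k_id] = k
            kv_cache[self.v_id] = v
        else:
            # encoder cross-attention: compute once per encoded mel, then reuse
            if self.k_id in kv_cache and self.v_id in kv_cache:
                k = kv_cache[self.k_id]
                v = kv_cache[self.v_id]
            else:
                k = self.key(xa)
                v = self.value(xa)
                kv_cache[self.k_id] = k
                kv_cache[self.v_id] = v
        return k, v


self_attn = CachedKV(n_state=8)
cross_attn = CachedKV(n_state=8)
xa = torch.randn(1, 100, 8)          # stand-in for encoder output (encoded mel)
kv_cache: dict[int, Tensor] = {}
for step in range(3):
    x = torch.randn(1, 1, 8)         # one new token embedding per decode step
    k_self, _ = self_attn(x, None, kv_cache)
    k_cross, _ = cross_attn(x, xa, kv_cache)
    # self-attention cache grows (1, 2, 3); cross-attention stays fixed at 100
    print(step, k_self.shape[1], k_cross.shape[1])

The property the diff exploits is that cross-attention keys/values depend only on xa, so they are loop-invariant during decoding and can be computed once and read back from the cache, while self-attention keys/values grow by one position per generated token.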