Commit ad9eb03: fix layers

evanarlian committed Oct 19, 2022
1 parent: 9b411f0 · commit: ad9eb03
Showing 1 changed file with 48 additions and 26 deletions.

model2.py (74 changes: 48 additions & 26 deletions)
@@ -9,6 +9,9 @@
from torch import Tensor


# TODO detach from the cache??


@dataclass
class ModelDimensions:
n_mels: int
@@ -41,28 +44,26 @@ def __init__(self, n_state: int, n_head: int):
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)

def forward(self, x: Tensor, mask: Optional[Tensor]):
def forward(self, x: Tensor):
# multi head attention used in the encoder
q = self.query(x)
k = self.key(x)
v = self.value(x)
wv = self.qkv_attention(q, k, v, mask)
wv = self.qkv_attention(q, k, v)
return self.out(wv)

def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor]):
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor):
n_batch, n_ctx, n_state = q.size()
scale = (n_state // self.n_head) ** -0.25
q = q.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(k.size(0), k.size(1), self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(v.size(0), v.size(1), self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
w = F.softmax(qk, dim=-1)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
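
A note on the scaling in qkv_attention above: q and k are each multiplied by (n_state // self.n_head) ** -0.25, which is just the usual 1 / sqrt(d_head) softmax scaling split across the two operands. A quick illustrative check (standalone PyTorch, not part of this commit):

# illustrative check, not part of this commit: split scaling == post-matmul scaling
import torch

d_head = 64
q = torch.randn(2, 8, 10, d_head)  # (batch, heads, ctx, d_head)
k = torch.randn(2, 8, 10, d_head)

split_scaled = (q * d_head**-0.25) @ (k * d_head**-0.25).transpose(-1, -2)
post_scaled = (q @ k.transpose(-1, -2)) * d_head**-0.5
assert torch.allclose(split_scaled, post_scaled, atol=1e-5)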


class MultiHeadCrossAttention(nn.Module):
class CachedMultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int, keygen: Generator):
super().__init__()
self.k_id = next(keygen)
@@ -74,25 +75,42 @@ def __init__(self, n_state: int, n_head: int, keygen: Generator):
self.out = nn.Linear(n_state, n_state)

def forward(
self, x: Tensor, xa: Tensor, mask: Optional[Tensor], kv_cache: dict[int, Tensor]
self,
x: Tensor,
kv_cache: dict[int, Tensor],
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
):
# multi head attention used in the decoder
# q will always come from the bottom (from the previous decoder block)
q = self.query(x)
k = self.key(xa)
v = self.value(xa)

# managing cache
if self.k_id in kv_cache and self.v_id in kv_cache:
kv_cache[self.k_id] = torch.cat([kv_cache[self.k_id], k])
kv_cache[self.v_id] = torch.cat([kv_cache[self.v_id], v])
# k and v can come from the bottom (from the previous decoder block)
# or from the encoder (this case is called cross attention)
# we just need to know whether the cross attention input (xa) exists or not
if xa is None:
# from decoder
# TODO original code still appends to the cache even though it is not used here, why?
k = self.key(x)
v = self.value(x)
else:
# from encoder, need to manage cache
curr_k = self.key(xa)
curr_v = self.value(xa)
if self.k_id in kv_cache and self.v_id in kv_cache:
k = torch.cat([kv_cache[self.k_id], curr_k], dim=1)
v = torch.cat([kv_cache[self.v_id], curr_v], dim=1)
else:
k = curr_k
v = curr_v
kv_cache[self.k_id] = k
kv_cache[self.v_id] = v

wv = self.qkv_attention(q, k, v, mask)
wv = self.masked_qkv_attention(q, k, v, mask)
return self.out(wv)

def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor]):
def masked_qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
n_batch, n_ctx, n_state = q.size()
scale = (n_state // self.n_head) ** -0.25
q = q.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 1, 3) * scale
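
For context on the k_id / v_id bookkeeping in this class: each CachedMultiHeadAttention pulls two unique integer ids from the shared keygen generator and uses them as its keys into the caller-owned kv_cache dict. A minimal sketch of that pattern, assuming keygen is simply something like itertools.count() (illustrative only, not the repository's exact wiring):

# illustrative sketch, not part of this commit; assumes keygen = itertools.count()
import itertools
import torch

keygen = itertools.count()               # hands out 0, 1, 2, ... as unique cache keys
kv_cache: dict[int, torch.Tensor] = {}   # owned by the caller, shared by all layers

k_id, v_id = next(keygen), next(keygen)  # one (k, v) slot per attention module

xa = torch.randn(1, 1500, 512)                     # stand-in for encoder output
k, v = torch.randn_like(xa), torch.randn_like(xa)  # stand-ins for key(xa), value(xa)

# first call: nothing cached yet, so store; later calls find and reuse these entries
if k_id not in kv_cache or v_id not in kv_cache:
    kv_cache[k_id], kv_cache[v_id] = k, v
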
@@ -118,19 +136,19 @@ def __init__(self, n_state: int, n_head: int):
)
self.mlp_ln = nn.LayerNorm(n_state)

def forward(self, x: Tensor, mask: Optional[Tensor] = None):
def forward(self, x: Tensor):
# standard encoder attention block with skip connection
x = x + self.attn(self.attn_ln(x), mask=mask)
x = x + self.attn(self.attn_ln(x))
x = x + self.mlp(self.mlp_ln(x))
return x


class ResidualCrossAttentionBlock(nn.Module):
class CachedResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, keygen: Generator):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn = CachedMultiHeadAttention(n_state, n_head, keygen)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = MultiHeadCrossAttention(n_state, n_head, keygen)
self.cross_attn = CachedMultiHeadAttention(n_state, n_head, keygen)
self.cross_attn_ln = nn.LayerNorm(n_state)
n_mlp = n_state * 4
self.mlp = nn.Sequential(
@@ -141,11 +159,15 @@ def __init__(self, n_state: int, n_head: int, keygen: Generator):
self.mlp_ln = nn.LayerNorm(n_state)

def forward(
self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor], mask: Tensor = None
self,
x: Tensor,
kv_cache: dict[int, Tensor],
xa: Tensor,
mask: Tensor,
):
# decoder attn and cross-attn block with skip connection
x = x + self.attn(self.attn_ln(x), mask=mask)
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, mask=mask)
x = x + self.attn(self.attn_ln(x), kv_cache, mask=mask)
x = x + self.cross_attn(self.cross_attn_ln(x), kv_cache, xa=xa)
x = x + self.mlp(self.mlp_ln(x))
return x
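
The mask threaded through this block is presumably the causal self.mask that the TextDecoder passes in further down; in Whisper-style code it is an additive mask with -inf strictly above the diagonal, so it can be cropped to mask[:n_ctx, :n_ctx] and added to the attention logits before the softmax, as the removed encoder branch above did. A hedged sketch of such a mask (not copied from this repository):

# illustrative causal mask sketch, not copied from this repository
import torch

n_ctx = 448  # assumed maximum text context length
causal_mask = torch.full((n_ctx, n_ctx), float("-inf")).triu_(diagonal=1)
# row i holds 0.0 for positions <= i (visible) and -inf for positions > i (future)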

@@ -200,7 +222,7 @@ def __init__(

self.blocks = nn.ModuleList(
[
ResidualCrossAttentionBlock(n_state, n_head, keygen)
CachedResidualAttentionBlock(n_state, n_head, keygen)
for _ in range(n_layer)
]
)
@@ -223,7 +245,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor]):
)

for block in self.blocks:
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
x = block(x, kv_cache, xa, self.mask)

x = self.ln(x)
logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
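
The last line computes the logits by multiplying the hidden states with the transposed token embedding matrix, i.e. the output head is weight-tied to the embedding instead of using a separate projection. A small equivalence sketch (illustrative only, with assumed sizes):

# illustrative check of the weight-tied output head, not part of this commit
import torch
import torch.nn.functional as F

n_vocab, n_state = 51865, 512   # assumed Whisper-like sizes
token_embedding = torch.nn.Embedding(n_vocab, n_state)
x = torch.randn(1, 7, n_state)  # decoder hidden states after the final LayerNorm

logits_matmul = x @ torch.transpose(token_embedding.weight, 0, 1)
logits_linear = F.linear(x, token_embedding.weight)  # same result, no bias term
assert torch.allclose(logits_matmul, logits_linear, atol=1e-5)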
