remove hooks, use standard passing
evanarlian committed Oct 18, 2022
1 parent fd981dc commit 797e1a1
Showing 1 changed file with 48 additions and 19 deletions.
67 changes: 48 additions & 19 deletions model2.py
@@ -1,16 +1,16 @@
+import itertools
 from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
+from typing import Optional, Generator


 import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
 from torch import Tensor
+from torch import nn


 # TODO remove float(), another data types call to(x.dtype)
 # TODO list kv cache fix


 @dataclass
@@ -36,6 +36,7 @@ def sinusoids(length, channels, max_timescale=10000):
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


+# HACK this is for encoder, no need for cache nor keygen
 class MultiHeadAttention(nn.Module):
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
@@ -66,26 +67,34 @@ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor])
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


+# HACK this is for decoder, need keygen, cache, and xa
 class MultiHeadCrossAttention(nn.Module):
-    def __init__(self, n_state: int, n_head: int):
+    def __init__(self, n_state: int, n_head: int, keygen: Generator):
         super().__init__()
+        self.k_id = next(keygen)
+        self.v_id = next(keygen)
         self.n_head = n_head
         self.query = nn.Linear(n_state, n_state)
         self.key = nn.Linear(n_state, n_state, bias=False)
         self.value = nn.Linear(n_state, n_state)
         self.out = nn.Linear(n_state, n_state)

     def forward(
-        self, x: Tensor, xa: Tensor, mask: Optional[Tensor], kv_cache: list[Tensor]
+        self, x: Tensor, xa: Tensor, mask: Optional[Tensor], kv_cache: dict[int, Tensor]
     ):
         # multi head attention used in the decoder
         q = self.query(x)
-        if kv_cache:
-            k, v = kv_cache
+        k = self.key(xa)
+        v = self.value(xa)
+
+        # managing cache
+        if self.k_id in kv_cache and self.v_id in kv_cache:
+            kv_cache[self.k_id] = torch.cat([kv_cache[self.k_id], k])
+            kv_cache[self.v_id] = torch.cat([kv_cache[self.v_id], v])
         else:
-            k = self.key(xa)
-            v = self.value(xa)
-            # TODO append something to the kv cache and manage return TODO FIXME
+            kv_cache[self.k_id] = k
+            kv_cache[self.v_id] = v

         wv = self.qkv_attention(q, k, v, mask)
         return self.out(wv)

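Note on the hunk above: instead of registering forward hooks, every cross-attention layer now draws two unique integer ids from a shared itertools.count() generator and uses them as keys into an ordinary dict that the caller passes down explicitly. A minimal, self-contained sketch of that id-plus-dict mechanism (CachingLayer and the tensor shapes are illustrative, not part of the commit):

import itertools

import torch
from torch import Tensor


class CachingLayer(torch.nn.Module):
    # hypothetical stand-in for the caching part of MultiHeadCrossAttention
    def __init__(self, keygen):
        super().__init__()
        self.k_id = next(keygen)  # unique int key for this layer's cached k
        self.v_id = next(keygen)  # unique int key for this layer's cached v

    def forward(self, k: Tensor, v: Tensor, kv_cache: dict[int, Tensor]):
        if self.k_id in kv_cache and self.v_id in kv_cache:
            # cache hit: grow the stored tensors
            kv_cache[self.k_id] = torch.cat([kv_cache[self.k_id], k])
            kv_cache[self.v_id] = torch.cat([kv_cache[self.v_id], v])
        else:
            # first call: store the fresh tensors
            kv_cache[self.k_id] = k
            kv_cache[self.v_id] = v
        return kv_cache[self.k_id], kv_cache[self.v_id]


keygen = itertools.count()              # shared by all layers, so ids never collide
layers = [CachingLayer(keygen) for _ in range(2)]
cache: dict[int, Tensor] = {}           # owned by the caller, passed in explicitly
for layer in layers:
    layer(torch.randn(1, 4), torch.randn(1, 4), cache)
print(sorted(cache))                    # [0, 1, 2, 3]

Because the cache is a plain dict keyed by plain ints, it can be created, inspected, or discarded entirely outside the model, which is what "use standard passing" refers to.
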
@@ -123,11 +132,11 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None):


 class ResidualCrossAttentionBlock(nn.Module):
-    def __init__(self, n_state: int, n_head: int):
+    def __init__(self, n_state: int, n_head: int, keygen: Generator):
         super().__init__()
         self.attn = MultiHeadAttention(n_state, n_head)
         self.attn_ln = nn.LayerNorm(n_state)
-        self.cross_attn = MultiHeadCrossAttention(n_state, n_head)
+        self.cross_attn = MultiHeadCrossAttention(n_state, n_head, keygen)
         self.cross_attn_ln = nn.LayerNorm(n_state)
         n_mlp = n_state * 4
         self.mlp = nn.Sequential(
@@ -137,7 +146,9 @@ def __init__(self, n_state: int, n_head: int):
         )
         self.mlp_ln = nn.LayerNorm(n_state)

-    def forward(self, x: Tensor, xa: Tensor, kv_cache: list[Tensor], mask: Tensor = None):
+    def forward(
+        self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor], mask: Tensor = None
+    ):
         # decoder attn and cross-attn block with skip connection
         x = x + self.attn(self.attn_ln(x), mask=mask)
         x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, mask=mask)
@@ -180,29 +191,38 @@ def forward(self, x: Tensor):

 class TextDecoder(nn.Module):
     def __init__(
-        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+        self,
+        n_vocab: int,
+        n_ctx: int,
+        n_state: int,
+        n_head: int,
+        n_layer: int,
+        keygen: Generator,
     ):
         super().__init__()

         self.token_embedding = nn.Embedding(n_vocab, n_state)
         self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

         self.blocks = nn.ModuleList(
-            [ResidualCrossAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+            [
+                ResidualCrossAttentionBlock(n_state, n_head, keygen)
+                for _ in range(n_layer)
+            ]
         )
         self.ln = nn.LayerNorm(n_state)

         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)

-    def forward(self, x: Tensor, xa: Tensor, kv_cache: list[Tensor]):
+    def forward(self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor]):
         """
         x : torch.LongTensor, shape = (batch_size, <= n_ctx)
             the text tokens
         xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
             the encoded audio features to be attended on
         """
-        offset = kv_cache[0][0].size(1) if kv_cache != [] else 0
+        offset = kv_cache[0].size(1) if len(kv_cache) > 0 else 0
         x = (
             self.token_embedding(x)
             + self.positional_embedding[offset : offset + x.size(-1)]
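
The new offset line reads the time dimension of the first cached tensor (key 0 is the first id the keygen hands out), the idea being that incremental decoding continues the positional embedding where the previous step stopped. A tiny illustration of the slicing, with made-up sizes and a stand-in cached tensor, not code from the commit:

import torch

n_ctx, n_state = 8, 4
positional_embedding = torch.arange(n_ctx * n_state, dtype=torch.float).view(n_ctx, n_state)

# first step: cache is empty, the whole prompt starts at position 0
kv_cache: dict[int, torch.Tensor] = {}
tokens = torch.zeros(1, 3, dtype=torch.long)                  # 3 prompt tokens
offset = kv_cache[0].size(1) if len(kv_cache) > 0 else 0      # -> 0
print(positional_embedding[offset : offset + tokens.size(-1)].shape)      # torch.Size([3, 4])

# later step: the cache already holds 3 positions under key 0
kv_cache[0] = torch.zeros(1, 3, n_state)                      # stand-in cached tensor
next_token = torch.zeros(1, 1, dtype=torch.long)              # one new token
offset = kv_cache[0].size(1) if len(kv_cache) > 0 else 0      # -> 3
print(positional_embedding[offset : offset + next_token.size(-1)].shape)  # torch.Size([1, 4])
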
@@ -224,6 +244,7 @@ class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
         self.dims = dims
+        self.keygen = itertools.count()
         self.encoder = AudioEncoder(
             self.dims.n_mels,
             self.dims.n_audio_ctx,
@@ -237,7 +258,15 @@ def __init__(self, dims: ModelDimensions):
             self.dims.n_text_state,
             self.dims.n_text_head,
             self.dims.n_text_layer,
+            self.keygen,
         )

     def forward(self, mel: Tensor, tokens: Tensor):
-        return self.decoder(tokens, self.encoder(mel), kv_cache=[])
+        cache: dict[int, Tensor] = {}
+        encoded = self.encoder(mel)
+        logits = self.decoder(tokens, encoded, kv_cache=cache)
+        return logits
+
+    def greedy_decode(self):
+        # TODO export, cache should be from here.
+        pass
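
greedy_decode is left as a stub, and its TODO says the cache should be created there. One possible shape for it, kept consistent with the explicit cache passing above; this is purely a hypothetical sketch, with model, sot_token, eot_token and max_len as assumed names, not code from this commit:

import torch
from torch import Tensor


@torch.no_grad()
def greedy_decode(model, mel: Tensor, sot_token: int, eot_token: int, max_len: int = 224) -> Tensor:
    # hypothetical sketch of a greedy loop that owns the kv_cache dict
    cache: dict[int, Tensor] = {}        # created here, as the TODO suggests
    encoded = model.encoder(mel)         # encode the audio once
    tokens = torch.full((mel.size(0), 1), sot_token, dtype=torch.long, device=mel.device)
    for _ in range(max_len):
        logits = model.decoder(tokens, encoded, kv_cache=cache)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick per batch item
        tokens = torch.cat([tokens, next_token], dim=-1)
        if (next_token == eot_token).all():
            break
    return tokens

Because the cache lives in a local dict rather than in module state, each call starts from a clean cache and nothing has to be cleared on the model afterwards.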
