Skip to content

Commit

Permalink
separate modules, make scriptable
Browse files Browse the repository at this point in the history
  • Loading branch information
evanarlian committed Oct 18, 2022
1 parent d2ec7f2 commit fd981dc
Showing 1 changed file with 83 additions and 106 deletions.
189 changes: 83 additions & 106 deletions model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
from torch import nn


# TODO remove float(), another data types call to(x.dtype)
# TODO list kv cache fix


@dataclass
class ModelDimensions:
n_mels: int
Expand Down Expand Up @@ -41,73 +45,102 @@ def __init__(self, n_state: int, n_head: int):
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)

def forward(self, x: Tensor, mask: Optional[Tensor]):
# multi head attention used in the encoder
q = self.query(x)
k = self.key(x)
v = self.value(x)
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv)

def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor]):
n_batch, n_ctx, n_state = q.size()
scale = (n_state // self.n_head) ** -0.25
q = q.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


class MultiHeadCrossAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = nn.Linear(n_state, n_state)
self.key = nn.Linear(n_state, n_state, bias=False)
self.value = nn.Linear(n_state, n_state)
self.out = nn.Linear(n_state, n_state)

def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
self, x: Tensor, xa: Tensor, mask: Optional[Tensor], kv_cache: list[Tensor]
):
# multi head attention used in the decoder
q = self.query(x)

if kv_cache is None or xa is None:
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
# otherwise, perform key/value projections for self- or cross-attention as usual.
k = self.key(x if xa is None else xa)
v = self.value(x if xa is None else xa)
if kv_cache:
k, v = kv_cache
else:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
k = kv_cache.get(self.key, self.key(xa))
v = kv_cache.get(self.value, self.value(xa))

k = self.key(xa)
v = self.value(xa)
# TODO append something to the kv cache and manage return TODO FIXME
wv = self.qkv_attention(q, k, v, mask)
return self.out(wv)

def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
n_batch, n_ctx, n_state = q.shape
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor]):
n_batch, n_ctx, n_state = q.size()
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

q = q.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(q.size(0), q.size(1), self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]

w = F.softmax(qk.float(), dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)


class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
def __init__(self, n_state: int, n_head: int):
super().__init__()

self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)

self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp),
nn.GELU(),
nn.Linear(n_mlp, n_state),
)
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
self.mlp_ln = nn.LayerNorm(n_state)

def forward(self, x: Tensor, mask: Optional[Tensor] = None):
# standard encoder attention block with skip connection
x = x + self.attn(self.attn_ln(x), mask=mask)
x = x + self.mlp(self.mlp_ln(x))
return x


class ResidualCrossAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = MultiHeadCrossAttention(n_state, n_head)
self.cross_attn_ln = nn.LayerNorm(n_state)
n_mlp = n_state * 4
self.mlp = nn.Sequential(
nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
nn.Linear(n_state, n_mlp),
nn.GELU(),
nn.Linear(n_mlp, n_state),
)
self.mlp_ln = nn.LayerNorm(n_state)

def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
def forward(self, x: Tensor, xa: Tensor, kv_cache: list[Tensor], mask: Tensor = None):
# decoder attn and cross-attn block with skip connection
x = x + self.attn(self.attn_ln(x), mask=mask)
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, mask=mask)
x = x + self.mlp(self.mlp_ln(x))
return x

Expand All @@ -121,7 +154,7 @@ def __init__(
self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
self.blocks = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = nn.LayerNorm(n_state)
Expand All @@ -135,7 +168,7 @@ def forward(self, x: Tensor):
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)

assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
assert x[0].size() == self.positional_embedding.size(), "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)

for block in self.blocks:
Expand All @@ -154,28 +187,25 @@ def __init__(
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
self.blocks = nn.ModuleList(
[ResidualCrossAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln = nn.LayerNorm(n_state)

mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)

def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
def forward(self, x: Tensor, xa: Tensor, kv_cache: list[Tensor]):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
the encoded audio features to be attended on
"""
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
offset = kv_cache[0][0].size(1) if kv_cache != [] else 0
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
+ self.positional_embedding[offset : offset + x.size(-1)]
)
x = x.to(xa.dtype)

Expand Down Expand Up @@ -209,58 +239,5 @@ def __init__(self, dims: ModelDimensions):
self.dims.n_text_layer,
)

def embed_audio(self, mel: torch.Tensor):
return self.encoder.forward(mel)

def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder.forward(tokens, audio_features)

def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))

@property
def device(self):
return next(self.parameters()).device

@property
def is_multilingual(self):
return self.dims.n_vocab == 51865

def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []

def save_to_cache(module, _, output):
if (
module not in cache
or output.shape[1] > self.decoder.positional_embedding.shape[0]
):
cache[
module
] = output # save as-is, for the first token or cross attention
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]

def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))

self.decoder.apply(install_hooks)
return cache, hooks
def forward(self, mel: Tensor, tokens: Tensor):
return self.decoder(tokens, self.encoder(mel), kv_cache=[])

0 comments on commit fd981dc

Please sign in to comment.