Commit a1f8bfc (1 parent: d18e9ea)
Showing 2 changed files with 533 additions and 0 deletions.
@@ -0,0 +1,266 @@
from dataclasses import dataclass
from typing import Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
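
# For reference (an assumption drawn from the released checkpoints, not stated
# in this commit): the multilingual "tiny" model corresponds to
# ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=384,
# n_audio_head=6, n_audio_layer=4, n_vocab=51865, n_text_ctx=448,
# n_text_state=384, n_text_head=6, n_text_layer=4).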


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
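
# Illustrative note (not part of the commit): sinusoids(1500, 384) returns a
# (1500, 384) tensor whose first 192 columns are sines and last 192 are
# cosines at geometrically spaced wavelengths, the standard Transformer
# positional-encoding scheme.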


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache.get(self.key, self.key(xa))
            v = kv_cache.get(self.value, self.value(xa))

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
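
# Note on the scaling above (explanatory, not in the commit): q and k are each
# multiplied by d_head ** -0.25, so their product carries the standard
# 1 / sqrt(d_head) attention scaling while keeping both operands small enough
# to reduce the risk of fp16 overflow in the matmul.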


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
        x = x + self.mlp(self.mlp_ln(x))
        return x
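
# Explanatory note (not in the commit): this is a pre-norm residual block,
# computing x + sublayer(LayerNorm(x)) for self-attention, for optional
# cross-attention over the audio features xa, and for a 4x-wide GELU MLP,
# in that order.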


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x
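
# Shape note (explanatory, not in the commit): conv2 has stride 2, so the
# input must contain 2 * n_ctx mel frames for the assert above to pass;
# e.g. 30 seconds of audio -> 3000 mel frames -> 1500 encoder positions in
# the released models.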


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = nn.LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits
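
# Explanatory notes (not in the commit): `mask` is an upper-triangular matrix
# of -inf added to the attention scores, so each position attends only to
# itself and earlier tokens; `offset` shifts the positional embedding during
# incremental decoding so a newly decoded token lands at its true position;
# and the output projection is tied to the transposed token-embedding matrix.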


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder.forward(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder.forward(tokens, audio_features)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865
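
    # Note (an assumption about the released tokenizers, not stated in this
    # commit): the multilingual vocabulary has 51865 tokens while the
    # English-only models use 51864, which is what the check above relies on.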

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache`, which stores the key
        and value tensors calculated for the previous positions. This method returns a
        dictionary that stores all caches, along with the hooks for the key and value
        projection modules that save the intermediate tensors to be reused during later
        calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary mapping each key/value projection module to its cached tensor
        hooks : List[RemovableHandle]
            A list of PyTorch RemovableHandle objects that can be used to remove the hooks
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if (
                module not in cache
                or output.shape[1] > self.decoder.positional_embedding.shape[0]
            ):
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
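
A minimal usage sketch (illustrative only, not part of the commit; the
dimensions below describe a made-up toy configuration, not a released
checkpoint):

    dims = ModelDimensions(
        n_mels=80, n_audio_ctx=1500, n_audio_state=64, n_audio_head=2,
        n_audio_layer=2, n_vocab=51865, n_text_ctx=448, n_text_state=64,
        n_text_head=2, n_text_layer=2,
    )
    model = Whisper(dims).eval()
    # positional_embedding is allocated with torch.empty, so give it values for the demo
    torch.nn.init.normal_(model.decoder.positional_embedding)

    mel = torch.randn(1, dims.n_mels, 2 * dims.n_audio_ctx)  # 3000 mel frames
    tokens = torch.randint(0, dims.n_vocab, (1, 5))

    # one-shot forward pass over the full token sequence
    logits = model(mel, tokens)  # shape (1, 5, n_vocab)

    # incremental decoding: the hooks capture decoder keys/values so that
    # later calls only need to process the newly decoded token
    cache, hooks = model.install_kv_cache_hooks()
    audio_features = model.embed_audio(mel)
    logits = model.decoder(tokens, audio_features, kv_cache=cache)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    logits = model.decoder(next_token, audio_features, kv_cache=cache)
    for hook in hooks:
        hook.remove()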