remove data type conversion
evanarlian committed Oct 19, 2022
1 parent bd8823f commit 9b411f0
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions model2.py
@@ -2,17 +2,13 @@
 from dataclasses import dataclass
 from typing import Optional, Generator
 
-
 import numpy as np
 import torch
 from torch import nn
 import torch.nn.functional as F
 from torch import Tensor
 
 
-# TODO remove float(), another data types call to(x.dtype)
-
-
 @dataclass
 class ModelDimensions:
     n_mels: int
@@ -62,7 +58,7 @@ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor])
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
+        w = F.softmax(qk, dim=-1)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 
 
@@ -105,7 +101,7 @@ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor])
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
+        w = F.softmax(qk, dim=-1)
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 
 
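The same one-line change appears twice above, once per attention module: the softmax is no longer computed in float32 and cast back. Below is a minimal sketch (not from this repo) of what the removed upcast did, assuming mixed-precision inference where the attention scores are float16; on older PyTorch builds the half-precision softmax path may require a CUDA device.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
qk = (torch.randn(8, 8) * 30).to(torch.float16)  # large-ish attention scores

# Old behavior: upcast to float32 for a more accurate softmax, cast back.
w_old = F.softmax(qk.float(), dim=-1).to(qk.dtype)
# New behavior: softmax directly in the tensor's own dtype.
w_new = F.softmax(qk, dim=-1)

# The two agree up to float16 rounding; the upcast matters for accuracy in
# half precision and is redundant when the model runs in float32 throughout.
print((w_old.float() - w_new.float()).abs().max())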
@@ -178,7 +174,7 @@ def forward(self, x: Tensor):
         x = x.permute(0, 2, 1)
 
         assert x[0].size() == self.positional_embedding.size(), "incorrect audio shape"
-        x = (x + self.positional_embedding).to(x.dtype)
+        x = x + self.positional_embedding
 
         for block in self.blocks:
             x = block(x)
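Here the removed .to(x.dtype) cast the sum back to the dtype it already had: torch's type promotion keeps the result in the operands' common dtype. A small sketch, with shapes chosen only for illustration:

import torch

x = torch.randn(2, 1500, 512)   # (batch, n_ctx, n_state), illustrative sizes
pos = torch.randn(1500, 512)    # positional embedding, broadcast over batch
print((x + pos).dtype)          # torch.float32; no explicit cast needed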
@@ -225,15 +221,12 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor]):
             self.token_embedding(x)
             + self.positional_embedding[offset : offset + x.size(-1)]
         )
-        x = x.to(xa.dtype)
 
         for block in self.blocks:
             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
 
         return logits
 
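The logits line keeps the weight-tied readout while dropping both the per-matmul cast and the final .float(): decoder hidden states are multiplied by the transposed token embedding table. A self-contained sketch with made-up sizes:

import torch
from torch import nn

n_vocab, n_state = 100, 16                 # hypothetical sizes
token_embedding = nn.Embedding(n_vocab, n_state)

x = torch.randn(2, 5, n_state)             # (batch, seq, n_state) hidden states
# Weight tying: the same matrix embeds input tokens and projects to logits.
logits = x @ torch.transpose(token_embedding.weight, 0, 1)
print(logits.shape)                        # torch.Size([2, 5, 100])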
