diff --git a/model2.py b/model2.py
index 55122e07f..74cce4270 100644
--- a/model2.py
+++ b/model2.py
@@ -2,7 +2,6 @@
 from dataclasses import dataclass
 from typing import Optional, Generator
 
-
 import numpy as np
 import torch
 from torch import nn
@@ -10,9 +9,6 @@
 from torch import Tensor
 
 
-# TODO remove float(), another data types call to(x.dtype)
-
-
 @dataclass
 class ModelDimensions:
     n_mels: int
@@ -62,7 +58,7 @@ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor])
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
+        w = F.softmax(qk, dim=-1)
 
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 
@@ -105,7 +101,7 @@ def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor])
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
+        w = F.softmax(qk, dim=-1)
 
         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
 
@@ -178,7 +174,7 @@ def forward(self, x: Tensor):
 
         x = x.permute(0, 2, 1)
         assert x[0].size() == self.positional_embedding.size(), "incorrect audio shape"
-        x = (x + self.positional_embedding).to(x.dtype)
+        x = x + self.positional_embedding
 
         for block in self.blocks:
             x = block(x)
@@ -225,15 +221,12 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: dict[int, Tensor]):
             self.token_embedding(x)
             + self.positional_embedding[offset : offset + x.size(-1)]
         )
-        x = x.to(xa.dtype)
 
         for block in self.blocks:
             x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        logits = x @ torch.transpose(self.token_embedding.weight, 0, 1)
 
         return logits
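For context, a minimal sketch of the attention-softmax pattern this patch removes versus the one it keeps (not taken from `model2.py`; the name `scores` and its shape are made up for illustration). The old upcast-then-downcast is presumably there to protect reduced-precision (float16/bfloat16) runs; while everything stays in float32, the two forms agree.

```python
# Illustration only, not part of the patch.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(2, 4, 8, 8)  # stand-in for the qk attention scores

# Old pattern (removed by the patch): force float32 for the softmax, cast back.
w_old = F.softmax(scores.float(), dim=-1).to(scores.dtype)

# New pattern: softmax in whatever dtype the scores already have.
w_new = F.softmax(scores, dim=-1)

# With float32 inputs the upcast/downcast is a no-op, so the results match.
# In float16/bfloat16 the two can differ slightly.
assert torch.allclose(w_old, w_new)
```

If the model is actually run in reduced precision (e.g. under `torch.autocast`), the direct softmax yields slightly different attention weights than the float32 path, so it is worth spot-checking decoder output after this change.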