Skip to content

Commit

Permalink
add tests (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
evanarlian committed Oct 18, 2022
1 parent 39a4257 commit bd8823f
Showing 1 changed file with 341 additions and 0 deletions.
341 changes: 341 additions & 0 deletions tests.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional, Generator\n",
"import math\n",
"import random\n",
"import itertools\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from torch import Tensor\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"import torchinfo\n",
"import whisper\n",
"import IPython.display as ipd\n",
"\n",
"from transformers import GPT2TokenizerFast\n",
"import model2\n",
"import whisper"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Original model loading"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"ori = whisper.load_model(\"tiny\", device=\"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Modified model loading"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['dims', 'model_state_dict'])"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'n_mels': 80,\n",
" 'n_vocab': 51865,\n",
" 'n_audio_ctx': 1500,\n",
" 'n_audio_state': 384,\n",
" 'n_audio_head': 6,\n",
" 'n_audio_layer': 4,\n",
" 'n_text_ctx': 448,\n",
" 'n_text_state': 384,\n",
" 'n_text_head': 6,\n",
" 'n_text_layer': 4}"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tiny_path = Path(\"~/.cache/whisper/tiny.pt\").expanduser()\n",
"with open(tiny_path, \"rb\") as f:\n",
" checkpoint = torch.load(f)\n",
"ipd.display(checkpoint.keys())\n",
"ipd.display(checkpoint[\"dims\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"model_dims = model2.ModelDimensions(**checkpoint[\"dims\"])\n",
"modded = model2.Whisper(model_dims).eval()\n",
"modded.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"scripted = torch.jit.script(modded)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple forward test"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"audio = whisper.load_audio(\"tests/jfk.flac\")\n",
"audio = whisper.pad_or_trim(audio)\n",
"\n",
"mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)\n",
"\n",
"# # detect the spoken language\n",
"# _, probs = model.detect_language(mel)\n",
"# print(f\"Detected language: {max(probs, key=probs.get)}\")\n",
"\n",
"# # decode the audio\n",
"# options = whisper.DecodingOptions()\n",
"# result = whisper.decode(model, mel, options)\n",
"\n",
"# # print the recognized text\n",
"# print(result.text)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n"
]
}
],
"source": [
"ori_encoded = ori.encoder(mel)\n",
"modded_encoded = modded.encoder(mel)\n",
"scripted_encoded = scripted.encoder(mel)\n",
"\n",
"print(torch.allclose(ori_encoded, modded_encoded))\n",
"print(torch.allclose(ori_encoded, scripted_encoded))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# TODO add decoder greedy decoding test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scratchpad"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class Dummy(nn.Module):\n",
" \n",
" def __init__(self, keygen: Generator) -> None:\n",
" super().__init__()\n",
" self.unique_num = next(keygen)\n",
" self.unique_num = next(keygen)\n",
" self.lin = nn.Linear(4, 4)\n",
" \n",
" def forward(self, x: Tensor, cache: dict[int, Tensor]):\n",
" if self.unique_num not in cache:\n",
" cache[self.unique_num] = self.lin(x)\n",
" return cache[self.unique_num]\n",
" \n",
" @torch.jit.export\n",
" def generate(self, x: Tensor):\n",
" print(self.unique_num)\n",
" cache: dict[int, Tensor] = {}\n",
" a = self.forward(x, cache)\n",
" b = self.forward(x*2, cache)\n",
" return a-b\n",
"\n",
"keygen = itertools.count()\n",
"dummy = Dummy(keygen)\n",
"sdummy = torch.jit.script(dummy)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
},
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.]], grad_fn=<SubBackward0>)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dummy.generate(torch.randn(3, 4))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
},
{
"data": {
"text/plain": [
"tensor([[0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.]], grad_fn=<SubBackward0>)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sdummy.generate(torch.randn(3, 4))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('ml')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "c788065d0627783e03f01588c616fdc081ccf79059243dc851e48ed3fc07eef9"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit bd8823f

Please sign in to comment.