Skip to content

Commit

Permalink
add decoding test
Browse files Browse the repository at this point in the history
  • Loading branch information
evanarlian committed Oct 19, 2022
1 parent ad9eb03 commit d4689b9
Showing 1 changed file with 97 additions and 34 deletions.
131 changes: 97 additions & 34 deletions tests.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Original model loading"
"# Model loading"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"ori = whisper.load_model(\"tiny\", device=\"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Modified model loading"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
Expand Down Expand Up @@ -99,26 +83,27 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"ori = whisper.load_model(\"tiny\", device=\"cpu\").eval() # original model loading\n",
"model_dims = model2.ModelDimensions(**checkpoint[\"dims\"])\n",
"modded = model2.Whisper(model_dims).eval()\n",
"modded.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"scripted = torch.jit.script(modded)"
"scripted = torch.jit.script(modded).eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple forward test"
"# Simple encoding test"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -141,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -162,13 +147,56 @@
"print(torch.allclose(ori_encoded, scripted_encoded))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple decoding test"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[50258, 50259, 50359, 50363]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# special-token ids for <|startoftranscript|><|en|><|transcribe|><|notimestamps|> in Whisper's multilingual (GPT-2-based) tokenizer\n",
"tokens = torch.tensor([50258, 50259, 50359, 50363]).unsqueeze(0)\n",
"tokens"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n"
]
}
],
"source": [
"# TODO add decoder greedy decoding test"
"ori_decoded = ori.decoder(tokens, ori_encoded)\n",
"modded_decoded = modded.decoder(tokens, ori_encoded, {})\n",
"scripted_decoded = scripted.decoder(tokens, ori_encoded, {})\n",
"\n",
"print(torch.allclose(ori_decoded, modded_decoded))\n",
"print(torch.allclose(ori_decoded, scripted_decoded))"
]
},
{
Expand All @@ -185,6 +213,20 @@
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -298,17 +340,38 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"ename": "TypeError",
"evalue": "cannot assign 'bool' as child module 'training' (torch.nn.Module or None expected)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/evan/Documents/whisper-torchscript/tests.ipynb Cell 22\u001b[0m in \u001b[0;36m<cell line: 12>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/evan/Documents/whisper-torchscript/tests.ipynb#X25sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m \u001b[39mreturn\u001b[39;00m x\n\u001b[1;32m <a href='vscode-notebook-cell:/home/evan/Documents/whisper-torchscript/tests.ipynb#X25sZmlsZQ%3D%3D?line=10'>11</a>\u001b[0m net \u001b[39m=\u001b[39m Net()\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/evan/Documents/whisper-torchscript/tests.ipynb#X25sZmlsZQ%3D%3D?line=11'>12</a>\u001b[0m net\u001b[39m.\u001b[39;49meval()\n",
"File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1858\u001b[0m, in \u001b[0;36mModule.eval\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1842\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39meval\u001b[39m(\u001b[39mself\u001b[39m: T) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m T:\n\u001b[1;32m 1843\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"Sets the module in evaluation mode.\u001b[39;00m\n\u001b[1;32m 1844\u001b[0m \n\u001b[1;32m 1845\u001b[0m \u001b[39m This has any effect only on certain modules. See documentations of\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1856\u001b[0m \u001b[39m Module: self\u001b[39;00m\n\u001b[1;32m 1857\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1858\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtrain(\u001b[39mFalse\u001b[39;49;00m)\n",
"File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1837\u001b[0m, in \u001b[0;36mModule.train\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m 1835\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(mode, \u001b[39mbool\u001b[39m):\n\u001b[1;32m 1836\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mtraining mode is expected to be boolean\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m-> 1837\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining \u001b[39m=\u001b[39m mode\n\u001b[1;32m 1838\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mchildren():\n\u001b[1;32m 1839\u001b[0m module\u001b[39m.\u001b[39mtrain(mode)\n",
"File \u001b[0;32m~/miniconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py:1242\u001b[0m, in \u001b[0;36mModule.__setattr__\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 1240\u001b[0m \u001b[39melif\u001b[39;00m modules \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m 1241\u001b[0m \u001b[39mif\u001b[39;00m value \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m-> 1242\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mcannot assign \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m as child module \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1243\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m(torch.nn.Module or None expected)\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 1244\u001b[0m \u001b[39m.\u001b[39mformat(torch\u001b[39m.\u001b[39mtypename(value), name))\n\u001b[1;32m 1245\u001b[0m modules[name] \u001b[39m=\u001b[39m value\n\u001b[1;32m 1246\u001b[0m \u001b[39melse\u001b[39;00m:\n",
"\u001b[0;31mTypeError\u001b[0m: cannot assign 'bool' as child module 'training' (torch.nn.Module or None expected)"
]
}
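],
"source": [
"class Net(nn.Module):\n",
"\n",
" def __init__(self) -> None:\n",
" super().__init__()\n",
" self.training = nn.Linear(5, 4) # a submodule named 'training' collides with nn.Module's boolean training flag, so net.eval() raises TypeError\n",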
"\n",
" def forward(self, x):\n",
" return x\n",
"\n",
"\n",
"net = Net()\n",
"net.eval()"
]
}
],
"metadata": {
Expand Down

0 comments on commit d4689b9

Please sign in to comment.