From 3ce85e539f974189380802eec151fb822684a8dc Mon Sep 17 00:00:00 2001 From: Dheeraj Mohan Date: Thu, 15 Aug 2019 22:27:37 +0530 Subject: [PATCH] Feat: Implement Dropout in Gluon Changes to be committed: modified: DropOutScratch.ipynb new file: DropoutGluon.ipynb --- DropOutScratch.ipynb | 23 +++--- DropoutGluon.ipynb | 183 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 9 deletions(-) create mode 100644 DropoutGluon.ipynb diff --git a/DropOutScratch.ipynb b/DropOutScratch.ipynb index 6defc37..79fc607 100644 --- a/DropOutScratch.ipynb +++ b/DropOutScratch.ipynb @@ -143,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -152,30 +152,35 @@ "num_inputs = 28*28\n", "num_hidden1, num_hidden2 = 256, 256\n", "num_output = 10\n", - "drop_probs = [0.0, 0.0]\n", + "drop_probs = [0.2, 0.5]\n", "\n", "params = init_params(num_inputs, num_hidden1, num_hidden2, num_output)" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0, acc: 0.798250 loss: 0.545852\n", - "Epoch 1, acc: 0.827333 loss: 0.483884\n", - "Epoch 2, acc: 0.840633 loss: 0.448361\n", - "Epoch 3, acc: 0.854817 loss: 0.406807\n", - "Epoch 4, acc: 0.861367 loss: 0.389544\n" + "Epoch 0, acc: 0.843933 loss: 0.449781\n", + "Epoch 1, acc: 0.849133 loss: 0.423806\n", + "Epoch 2, acc: 0.853700 loss: 0.407548\n", + "Epoch 3, acc: 0.857400 loss: 0.393078\n", + "Epoch 4, acc: 0.865700 loss: 0.377074\n", + "Epoch 5, acc: 0.869967 loss: 0.364002\n", + "Epoch 6, acc: 0.872833 loss: 0.356155\n", + "Epoch 7, acc: 0.871917 loss: 0.355115\n", + "Epoch 8, acc: 0.874833 loss: 0.346407\n", + "Epoch 9, acc: 0.881317 loss: 0.328503\n" ] } ], "source": [ - "epochs = 5\n", + "epochs = 10\n", "lr = 0.1\n", "loss = gluon.loss.SoftmaxCrossEntropyLoss()\n", "for epoch in range(epochs):\n", diff --git 
a/DropoutGluon.ipynb b/DropoutGluon.ipynb new file mode 100644 index 0000000..b950322 --- /dev/null +++ b/DropoutGluon.ipynb @@ -0,0 +1,183 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from mxnet import nd, autograd, gluon, init\n", + "from mxnet.gluon import nn\n", + "from mxnet.gluon.data.vision import transforms" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def get_dataloader(batch_size):\n", + " transformer = transforms.Compose([\n", + " transforms.ToTensor()\n", + " ])\n", + " train = gluon.data.vision.datasets.FashionMNIST(train=True)\n", + " train = train.transform_first(transformer)\n", + " train_iter = gluon.data.DataLoader(train, batch_size, shuffle=True, num_workers=4)\n", + " test = gluon.data.vision.datasets.FashionMNIST(train=False)\n", + " test = test.transform_first(transformer)\n", + " test_iter = gluon.data.DataLoader(test, batch_size, shuffle=False, num_workers=4)\n", + " return train_iter, test_iter" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def show_images(X, nrows, ncols):\n", + " _, axes = plt.subplots(nrows, ncols)\n", + " axes = axes.flatten()\n", + " for ax, img in zip(axes, X):\n", + " ax.imshow(img)\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + " return axes" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def softmax(X):\n", + " X_exp = X.exp()\n", + " normalization = X_exp.sum(axis=1, keepdims=True)\n", + " return X_exp / normalization" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def accuracy(net, data_iter):\n", + " acc = 0\n", + " size = 0\n", + " for X, y in data_iter:\n", + " result = net(X)\n", + " y_hat 
= softmax(result)\n", + " acc += (y_hat.argmax(axis=1) == y.astype('float32')).sum().asscalar()\n", + " size += len(y)\n", + " return acc / size" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 256\n", + "train_iter, test_iter = get_dataloader(batch_size)\n", + "\n", + "net = nn.Sequential()\n", + "net.add(nn.Dense(256, activation='relu'))\n", + "net.add(nn.Dropout(0.2))\n", + "net.add(nn.Dense(256, activation='relu'))\n", + "net.add(nn.Dropout(0.5))\n", + "net.add(nn.Dense(10))\n", + "\n", + "net.initialize(init.Normal(0.01))\n", + "\n", + "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate':0.1})\n", + "\n", + "loss = gluon.loss.SoftmaxCrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0, acc: 0.665200\n", + "Epoch 1, acc: 0.773100\n", + "Epoch 2, acc: 0.803400\n", + "Epoch 3, acc: 0.819400\n", + "Epoch 4, acc: 0.829000\n", + "Epoch 5, acc: 0.828800\n", + "Epoch 6, acc: 0.833100\n", + "Epoch 7, acc: 0.838100\n", + "Epoch 8, acc: 0.840300\n", + "Epoch 9, acc: 0.847100\n" + ] + } + ], + "source": [ + "epochs = 10\n", + "for epoch in range(epochs):\n", + " for X, y in train_iter:\n", + " with autograd.record():\n", + " result = net(X)\n", + " l = loss(result, y)\n", + " l.backward()\n", + " trainer.step(batch_size)\n", + " epoch_acc = accuracy(net, test_iter)\n", + " print(\"Epoch %d, acc: %f\" % (epoch, epoch_acc))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "for X, y in train_iter:\n", + " show_images(X.squeeze(axis=1).asnumpy(), 2, 5)\n", + " break" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}