diff --git a/RegressionScratch.ipynb b/RegressionScratch.ipynb new file mode 100644 index 0000000..c0ed99d --- /dev/null +++ b/RegressionScratch.ipynb @@ -0,0 +1,206 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 133, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from mxnet import nd, autograd\n", + "import random" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_data(w, b, num):\n", + " X = nd.random.normal(shape=(num, len(w)))\n", + " y = nd.dot(X, w) + b\n", + " y += nd.random.normal(scale=0.01, shape=(y.shape))\n", + " return X, y " + ] + }, + { + "cell_type": "code", + "execution_count": 135, + "metadata": {}, + "outputs": [], + "source": [ + "def plotter(X, y):\n", + " plt.scatter(X, y, 1)\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 136, + "metadata": {}, + "outputs": [], + "source": [ + "def model(w, b, X):\n", + " return nd.dot(X, w) + b" + ] + }, + { + "cell_type": "code", + "execution_count": 137, + "metadata": {}, + "outputs": [], + "source": [ + "def loss(y_hat, y):\n", + " return ((y_hat - y.reshape(y_hat.shape))**2)/2" + ] + }, + { + "cell_type": "code", + "execution_count": 138, + "metadata": {}, + "outputs": [], + "source": [ + "def data_iter(X, y, batch_size):\n", + " indices = list(range(len(y )))\n", + " random.shuffle(indices)\n", + " for i in range(0, len(y), batch_size):\n", + " j = nd.array(indices[i: min(i+batch_size, len(y))])\n", + " yield X.take(j), y.take(j)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO2df3BV53nnv48BZZFIQRKKYoxkgZBxlQxVjexQDHYMeBu3HtxkBjfpTkva7iqeaSnNemcTx95N2jibTLu2l3p3YpONEzyTJoGmWXuo02IwxvyIHQuXUFsGS+KHAbvicoWpJTG+SLz7x73v0Xve+55f955zz7nnPp8Z5op7zn3f9xzpft/nPO/zPg8JIcAwDMOkk2viHgDDMAwTHSzyDMMwKYZFnmEYJsWwyDMMw6QYFnmGYZgUMzPuAajMnz9fdHR0xD0MhmGYquLw4cMXhBAtpmOJEvmOjg709/fHPQyGYZiqgohOOx1jdw3DMEyKYZFnGIZJMSzyDMMwKYZFnmEYJsWwyDMMw6QYFnmGYZgUwyLPMAyTYljkGYZhKsjoeA5P7hvG6HiuIv2xyDMMw1SQHf1n8M2fHcOO/jMV6S9RO14ZhmHSzobeNttr1LDIMwzDVJCmhjp84fbOivXH7hqGYZgUwyLPMAyTYljkGYZhUgyLPMMwTIphkWcYhkkxLPIMwzAphkWeYRgmxZQt8kTURkR7iWiAiN4gos2F95uI6HkiGiy8NpY/XIZhGCYIYVjykwDuF0J0A1gB4E+IqBvAlwHsEUJ0AdhT+D/DMAxTQcoWeSHEu0KI1wo/vw/gTQDXAbgHwLbCadsA/E65fTEMwzDBCNUnT0QdAH4dwCsAWoUQ7xYO/SuAVofP9BFRPxH1ZzKZMIfDMAxT84Qm8kQ0B8BPAPy5EOLf1GNCCAFAmD4nhNgqhOgVQvS2tLSENRyGYRgGIYk8Ec1CXuB/IIT4+8LbI0R0beH4tQDOh9EXwzBMWFQ6t3schBFdQwC+C+BNIcSjyqFnAWws/LwRwDPl9sUwDBMmlc7tHgdhpBq+FcDvA/gXIjpSeO8rAL4FYDsR/TGA0wDuDaEvhmGY0Kh0bvc4KFvkhRAHAJDD4bXlts8wDBMVlc7tHge845VhGCbFsMgzDMOkGBZ5hmGYmIkyyodFnmEYJmaijPLhQt4MwzAxE2WUD4s8wzBMzEQZ5cPuGoZhmBTDIs8wDJNiWOQZhmFSDIs8wzBMimGRZxiGSTEs8gzDMCmGRZ5hmKogCbnfkzCGoLDIMwxjkWQRS0Lu9ySMISi8GYphGAspYgASl4I3CbnfkzCGoLDIMwxjkWQRS0Lu9ySMISgs8gzDWFSjiDHuhFXI+ykiOk9EryvvfY2IzhHRkcK/3wqjL4ZhGMY/YS28fh/ApwzvPyaE6Cn8ey6kvhiGSSFJXvStZkIReSHESwBGw2iLYZjqpxTBrsbIlWog6hDKPyWiowV3TqPpBCLqI6J+IurPZDIRD4dhmEpQimBv6G3DA3fdmMhFX5Vqe+KIUuS/DaATQA+AdwE8YjpJCLFVCNErhOhtaWmJcDgMw1SKUgRbLvo2NdRFODJ3/Ah4tT1xRCbyQogRIcSUEOIqgO8AuCWqvhiGCY9SLVX1c1EJtj62MK3q0fEc7t9+xFPAy3niiOMpIDKRJ6Jrlf9+GsDrTucyDJMcgliqqmj5+ZyXyHkd1/so1ao29bOj/wz2Hs/gjqUt2NDb5jiWciawOJ4CQomTJ6IfAvgkgPlEdBbAVwF8koh6AAgApwB8IYy+GIaJliAbotQdsn4+57Wj1uu43kepm7dM/ahtNTXU4cl9w6Hv/o1jsxkJISrWmRe9vb2iv78/7mEwDOMTacFLYSz3/KDtlYqffio1ljAgosNCiF7jMRZ5hmGqnbgFOezJLihuIs9ZKBmGcaRawgXjjngJ2n8lx8u5axgmIcRtDZqIKytl0GuLO7Fa0P4rOV4WeYZJCEEFNWwB1oV1dDyHidwUNq9dUnK44LZDJwEQNq7sCDQRyWubyE2ivm6mp9jHnVgtaP+VHC+LPMMkhFKtwXXdrXhy33DZFr0+aezoP4MtewbxwF03lhwuuGXPEACgvm6Gq6jpE4y8toncVGLz2zsR9/qADos8wySEUq3BsEL9wgpPlKzrbsX+wQy6r53r2YY+wchrGx3Pob5uRuJTHagkrfAKL7wyTJUTVs4XdZJ5ct8wANg2/Tgtwjq9v3tgBAeGsmieU+dp0erXINvUxxAGfhaTy1lwTloOHhZ5hqlywk4h4BT5EfT9IGKnX0PQ6JMgouyn7XKiX5KQg0eF3TUMU0GS5q814eSmWdfdipdPZLGuu9V4vr42YHI/+b3+oK4i3UXi1o+ftuOO1gkTtuQZpoLEHc/tBydLdPfACPYez+DZI+/YrGZ5/u6BkdAsZN115GWh608Nbv34sbSTZo2XA1vyDFNBwrQQw3gqCNLGdMTLpNFqlhZ+mBay30VM/akhTZZ4ubDIM0wJlCqwYcZHhxHFEaQNe8TLzCKrOUgbgL97yGJdPizyDFMCSQiT8yOAXkJaiojKOHbZbpiZIIOM3Y1th05hy55BTOSm8MU7b6iKtZCoYJFnmBIIIrDruluxe2Ak8nQFpv68hDTok4XsYyI3aW10+sLtnbY2/Gaa9HLvlDeRCttrEibluGCRZ5gS8COOUlhePpHF3uP5+sVhpCtwElpTf+pkVK41Kysn7T2ewea1XdZCp95uuTnjJaU8IcixrO+5zuZS2tDbhoncJCZyU1blqqBU69MAizzDRIQaWrhi8UhoyaukSKpCa+pvXXerTZTK3RmrVk5Sc9Ho7XqN3694l7J+4TSBNDXUob5uJr75s2OeKRaCtp10WOQZJiJ037Vf688rvlyvYKTTWJ/vV1rdgF18S81149Rv3kqewkRu0tdmpKCLr0Fwe3IpdxG3WheBwyr/9xSAuwGcF0J8vPBeE4AfA+hAvvzfvUKIi2H0xzDVQljWn96OlxsEgK1eKVBarhtdKJ0mHwDYsmcI9XV5SZFuo0fu7XEV77CtYz2+3pQPJ4y2q4mwLPnvA/jfAJ5W3vsygD1CiG8R0ZcL//9SSP0xTFUQlvWn7jZ1sn5NfZW64xOw++CBYhEeHc9h09++hoPDWfStXmRzHe0fzGDv8Qye2DeM5oIVrebAkQuvE7lJbF7bFYl1XK2Wd9iEsuNVCPESgFHt7XsAbCv8vA3A74TRF8NUE6XunNRzscjdprsHRhx3c6p9mfoNmvRL9cHricOkUB8czgIAZtfNtPW9/PpGAMAb5y4VjVWO/+GdAwXrf0YkC5lp2rVaDlH65FuFEO8Wfv5XAK2mk4ioD0AfALS3t0c4HIapHnQ3hm6Vqj5wvwWx1TblWoFTeGe+YEjeylYXWdUInk1runBrZzM+dt1cbFzZYet/48pFqK+baWtfUuqCdFCqNRombCqy8CqEEERkrBguhNgKYCuQL+RdifEwTNLRRV33B9fXzShEisx09c9P5KasfOxqm17hnbLgh14wZF13K7YXLHwAODicxW03tBSJqDreztvnALCLrn7MhHq+HJPJ7eMk4tUaDRM2UYr8CBFdK4R4l4iuBXA+wr4YJlV4LfL5DVNU88wEsaad2t89MILhzDg6WxqwaU0XVixuDj0Pjel8AEWf9WqvVJ+80+Si/lxNTwZRivyzADYC+Fbh9ZkI+2KYmsBkDZsw5ZlxitDpvH2O5WuXAuYU/rmht82y/l89Neo4BpOVHXRjltNCsttx0z0IitPkov5cTU8GYYVQ/hDAJwHMJ6KzAL6KvLhvJ6I/BnAawL1h9MUwSSAuf29Qa1gVOl0U1Wsw+esnclPYsmfQ1ldTQx0eursbwABu7mjCk/uGjX590zhN4Y2qO8nN5WO63qhCGv1OLtVCKCIvhPicw6G1YbTPMEkjLn9vOWkKdFHcdugktuwZwkRuEhtXLrLalcm9+lYvNlZ2kpE+QD4W3+TXd7Oy1UVdQISapCyMyddtcqkmC17CO14ZpgS8XAVRWfpum32cxrHt0EkAZEXJyLFdvnK1cBZpn8rHP8yuu8YYGy8Fen3PAqxYPIKbO5oAwFYxShdKNTb+4Z0D2Hs8Y00gao4ZnXL8+F4VomoFFnkmkSTpy2kai5erwG+SsaAZKlX8LCzKKBkAlltEbnDqW70YdyxtwfqeBbZom4fu7nYUXjXqprNlDjpvn4Mn9w1j7/EMViwecYyW0aN5ZOx9uQvMXudzhA2LPJNQyrHIwp4gShEKL3FyC2H0O34/Pul8XplJAGT52qXIzq67xhJndUF1xeIRx7GYrsup9mvxOKZwOTeJZQvn2WLv9X7k/fEzCTjdk+HMGO7ffgSb1nT5LiaeVljkmURSjkXm99ygRaWDJPbya6GaQhhNi6Cl+qSbGurwxTuX2vqVor++Z4G1YWlH/xk8dHe341jUBVL9uqSP3s2Sz2eBnIEtewaLYu/1awbg6/rdfn/SJQQA3/vDWzzvW5phkWcSSTk1O/2eG7R+qJcPPMgThGmzkGn8QQXfT5EQPeWu23VJC/zlE1m8cnLUqrTkNF433BaNnSJa3K7H7ZiMAMq/1jhCiMT8W758uWCYKMiOfSCeeHFIZMc+ML5nOu6nDZUnXhwS139pp3WOV3t+x/joruPi0V3HxKO7jlvtOzF0/n3x+adeEUPn3/fdj9dn5HVd/6WdYsO3DxqvKej1qvfKDbd2vX6ntQSAfuGgq6EkKGOYpGNK6qUmsHJK+qXilfBqQ2+b5f/ddugkvvmzY4XIFmf0hF+mMebdHEMAhKd/WU1k5vc65Gce3jlgJURTx7Whtw2rlswHAPzi1EXjPZJjv3/7EV855eW9ki6w0fFcUVI201jdrkMdh9vvUcXUZ1hE2XYQ2F3D1AR+0wCUs0BndzGR9mpGj2oxlahTx+bmBlLDG4PE0auLrjv6z9hKCQJ5V8jffO7X8cS+Ybxx7pJxgdXUhml8+oKq6ioCyt9RGvT3GGX0TVIie1jkmZrAayG01CpOTmxc2WEtVLphT/g1gGUL52HLnkFbiTp17MOZMTy8M+9r7myx+/L1pGJBioMsWzgPyxbOLRJJdbKYPesaHBzO4uGdA8ZiIF0f+TCuTF11jLIxiZ7b7lKnScot/NSrqlapufV1Sk3LEAcs8kzN41QY23SO+qV2+6L7Df1TE37tPZ7BsoVzsXltly2NsNrPdNTIQFHUiEmg1Vcn8pODPerFtOlq89ou3LG0xWit7+g/g637T1jXZIqyMY3HbXepkyXslUHTdH1Ok10Y+W2cPp+USlIs8kwqCRLpIr+wemFs0zmAmtvFeVLwi4wx37SmC6+eGrVFlMg0wmrfblEjbqJSjjWrhnsCsFn86nW89FYGH7tuLtZ1t+Kx549D3WXrNT4TN3c0obOlwdpRaxqPn3z0UVjrSbHS/cAiz6SSIP5QPz5vU1ij26QA+Jtonj3yTsGCn2eFJuquEtXP3tRQZ1nwQfKpA2Z/t58xqv5zafEDsO0Z2D0wYuWW3z0wYttlW+oE+PgLgxjOjOPxFwZtTy1u4adu4w9KFE8AccAiz6QGVbCCWFp+vrBO2RzdnhLUBGDqhiQ7+Twxl3NTlmiqSD/75rVdRWIcNJ+6qZpUkJ3FTvH7auWqmzua8Miu4/iD37gejfV1gSZA/b24Y92ryVp3g0WeSQ1O+dLd0IUliGXrTT6y5vKVq447ZWWZPFncYyI3iaNnL1m+5mmBnixyF0n3idNCpZ9qUrqQ+bVe9c/JY3/4vV/g4HAWdTOvwV/e83HHO2PqR3+vs2VOrLtVq8lad4NFnkkNpVheurD4cfP4mQjysdGiKJ2uHsGjF/eYyE3ZEnhdnMhZPnun4h9O16KP1c/CZ36H63QYp2zXZHGv627FtkOnAAhsXLkokPVtit7RJ61S7jtTDIs8kxqChs8BpUWk+JkI1HDGDb1tVo4YmQFyIjdpibYu9kDefbPt0EkcPv0eDgxdAJDPwZL30U+hb/XiIteLuiApnxyCPN3IewWQFcYJOJfdU6Nbjp69hEfu7UFjfR1WLG5GY73/tA5+Qz2TEndebbDIM1WLn3S9fnK5qO/7zewIuCcs29DbhuxYDvsHM1jX3WotXEorHSBHqxuAFY547/KFVj1VeT1b9gxaoYyq68W0wSjI0830gvKSogVl08/ruluxbOE7OHx61AqrBIJvaHIbY6nrLMw0kYs8EZ0C8D6AKQCTQojeqPtkagMpSi+9lcHB4SyyYx+gec6HbKLrJw2ujpf17ySm+mf2HMvHwD+8Mx/TrouUvllKFdlVS5pxYCiLdy5dxnBmHK+eGsVN1zfaQi6dimjrC8OliK3qmnG6fgD44p03FKUJVtsqd42jlHUWxk6lLPk7hBAXKtQXU6UE9blKIcmO53BwOItfnr1UlCnx2SPnCiGK5xwjXPR+/boF1nW3Yv9gBtmxnBVJI8d1//Yj1ianTWu6LItfbU+6bx66uxuN9XVWqOTGlR3YuHJR0VMKoKb1bbbcO3pMui6apeanH86Moe/pfgxnxl3vRSkbmvzC1nv5sLuGiY1SxVWi+rGbG+qQHcvhlZOjkGGJAKwSd9Ol7opxCgfUhUUf7+6BERwYyuLAUNYWM7+jf7owx0N3d9tym6vXJd+/MvUGrkxdxSsnR9G3erElxGosuEx2pS5Ojo7nsOlvX8PB4SyA4ph0Od7sWA5b958whnJ65WSXE5WXK8WpAIhfkXZqK8iTCC/MmqmEyAsAu4hIAHhSCLFVPUhEfQD6AKC9vb0Cw2GSgl9x9cIm9nPssdmzZ11jezUJgVM4oNd4pSXffe1c285OPZ5cjZZRkZEoXR+Zg63789kqdxw+g9+9pc2Yl0afAJ/cN2wJ/Kol84tcJHJH7qolzYVWipOlOcXJr+tutXLR/MU9HzcW7JCLyOqYTO/7WQwPY1GVF2bNVELkVwkhzhHRRwA8T0THhBAvyYMF0d8KAL29vcKpESZdyCiRzWuXeIqr3/ZMVpyMQ3eLA1cnCWktmxZy9clAWvKru1psPmxZOFv/jC6UMg487/cm7Dh8Bhcnrlg+fBW97+kom0UACLMLkTDqNcqnC93l49auHjkja7nqOE1ebpOaPj71d1CuW8b098TkiVzkhRDnCq/nieinAG4B8JL7p5i0Y0qKVW57JivOFAeuvkpGx3P4sx++hgNDWewfzODAULaktmToJABczk0WLQTLvvRY+a/89q/id29pszJM6uh9yzWAzWuXQIY8AvmFUNPE4laWz3RN+ciZc0Vpj03X7jYRmnLY+InXD0rYf09pIlKRJ6IGANcIId4v/PzvAfxllH0y1UHYC2p+23NzxUhh7752LlZ3tXjWdDW1taG3zZokdg2M4FR2AoC/xUinHZ6mUFF1rUG6ogDhe6OWfNrYuLLDGpMejaOWCXTaxOWGOuE5pU4OC16gdSZqS74VwE+JSPb1t0KIf4y4T6YKCPuL7qc9r7wsE7kpqLs3g+RjV8fxN5+7yfJLq24Lt52dbrlczl28jKdfPo1/fP1d/POZS/lsjwt+BUB+rUF1SXn5pXWfOQAcPn0RB4YuFNVv1dcWpBvHlEve1Hb+nuaLhkctvmlJQRAFkYq8EOIEgF+Lsg8mmZSSEybMfkx+da+8LKUWqDbR1fph5Cav4qG7u439S+GU48yOfYCt+0/aImDk+R3N9QCAC2P5mPWDw1n0djRZ0TymnDJOTyHSZ35rZzN6O5oACGtHrRqVJO+J2q5X5SfdH5+/p/7CVpno4BBKJjB+vqCl5IQpBad+TAUlgoq2viDr5QKR92TboVPY+lJxAQ1VgKXFu3/wAg4MXcBN7fMAAJdz06Ge8vybO5rw+AuD2LSmC/veyuBy4YnDbWJzut9yM5WsLDWcGcPh0xcLEUKLXO/FI/f22DY96X8HbovMOhwJUzlY5JnA+PmCql94PSe6F0GsPF24VSFVC0p4pbaV1+X2BKC7KtTPq/dEWsS3djZbFrUuwHuPZ7BqyXxcmZoCAMyakY/GUaNkgHz2yX1vZfDQ3d149sg7AARm180o5JaZ6TixOU1o05up8pOPKULICd0lov8dBHGZuO1EZis/XFjkmcD4sYj1BFRq7VEvnCYR05ffLfeMGlHildoWgE0oVUF3clWo4iojYqbvCQEQePbIOWzZM4Tt/WesnaPyHBnHfsfSFmxa04XHXxjE+p4FSpz7lLVwefTse5aA67llTBObHLfXngDT79JNZJ02Ovn5jDrR6ZONn98/Uxos8kxggi5yefmJnc7XJ5Eg4u+nTdWazKcVmMLF8Q/Q0VxvE3Tpqth26JQt86Mq/lJc5Thk7vZVS+bj1s5mHBzO2nzVanph1bpfsXgEAKwcNjJEcn3PAixbmLfk5cKwTmO9u6UNeCdkc9rk5NSmn0ySQZ40vI4xwWGRZyLHlNDLbZJwmkS8xH8iN2VL+qWH/OmWrSy919V6Bs0NdQAEnn75bQD5HaSqoDc11FnCffTsJTx0dzd2D4zgobu7LYHfdugUtuwZxERuChtXdliitnltF267oaVoEjItmKrXpp+vJwPT3Unq9es+cjd0t5PbZianNv2Itv6k4TcXDlMeLPI1TKV9n+VaaF7ir1ZPApzzoEsxvJybBAC8ce4SDg5nsWpJM/pWLy74xQW27BkqqqI0bY0OFLl1pqNTRNFCpVwg1Z9k9LBKeZ3qhCTH7rSgql///sELWH79PGxcuch4v4YzY9amq86WObYyhXLx1elvwskN5Fe0/dRkZcKFRb6GKcX3Wc7EELWFtr7nOlsKA2BaAIczY3jprQz6blsMWalJbvu/uaMJ//p3v7QWIKeLdxAmcpMYzoxZvmQp3HlRHrC5dfQUCjrq/V7X3YqHdw6grXE2nn757aJdtvKpYP9gBsuvb8KWPYMOvn97FJAsHXhg6IJtglKZTpgm0yfInDbke5NT1D5zXnwNDxb5GqYUyzqJX3BT/PmGXnta34d3Dli1Rx+5t8dWlemx549jODOOm9rnYv/gBazrbkVnyxybe0a12mUfqqsGKA65lAurQHECNun3bmucDQBYPH8OVnepLpL8U8GBoSy6r/0VdLY0WD57p7BOde1Ahlia0Ev0re9ZgKNn38P6ngW+fi9eZfrCgBdfw4NFvoYpxbIOc1HMScxNX3Cv3ary1UnwVWErvu68JXvywoQtQZgap/726IRr9SNTOKWafhiw3285nramejz989P4d7PsoZN5t0k+QgcgDGfGbX5yp+yRG3rbijZ16ejpE9wiXVQqKby8+BoeLPJMIMJ0uTiJhukL7rat3mkBU2/flBcGyFuyO4++Y+VOlxauumCsiyxgrqWqu1ScnkTUDJTXzZuNidxUUcy5XGjdduikVUzEKRTSjwA77RXwu4ehksLLi6/hwSLPxIYuGnoMtppES49Vd1qY9Nu+KnJffeYNDGfGsWrJfPzFPR/Dj189gyNvX0RPeyPu09wsqshOF+XO70BdtWS+VYWqvs7+1XJKt2APp5xRdA0yyZe+x0D93P947k388sxF9K1e7CrApolStr95bZeni4yFtzphka8hol7MCtq+2w5KAEUZDNVoFdNGJv1nvWCF04Yombul+9oP26o4/eLURTQXxqhHlKihhjISZ/PaJVjdNb/IKlf7NqVb8OuKMrGj/4yVQqG+bqZrniB9olQLn7x84gJeOXnRNi4mHbDI1xBh+FTd0gPkd2gOBmrfrVScnsFQd8tM5CYxkZvC+p4FmMhNIjuWAwi2whGqBW0qKpFvZwqXc5MYePd9HBi6gE8saoIQAj3tjda5Mszw3MUJnLl4GRt/o8Mq79dYX2dbyDVZ5ap/X77qrh79vvnd5JUdz+GNc5csN5FT3Lwp/4xMawDAtdAHU72wyNcQYfhU3dID5Hdodtk2EZlwyvmiW95qBkM1B/r6ngXYPTACWSzj6Nn3sGzhXGzdn7doVdeGbkGrx+Q4Nq7sKFj0J3HH0hbLlaGOUyYO23s8gzMXL+OXZ9/D6PgVK9rFa1KT8eXqblKv3Z9ek7Ic330FH77EtG/AacernOTcdtIy1Q2LfA0Rhk/VJEimxU6nGG2gOORRb8+EtKSB6TwuN7XPQ2P9rEIxbIG+1Ysw2yFOXlrQ6qYjfRxS7EzHZeKwO5Z+BAeHL1gLsXqqBjUtgFyE1ROTyScAffenHhbpx1XjVg1LTZvghCnFMpMuWORrkHI3NOn+aa/t+Xq/pjjrixPTx549cg56yTgZ5tjRXI9Na7oATFvCTQ2zcGDoAlZ3zTf6uWWEjB4mKIV9IjeJixM5HD49arkuvnjn0qJrka6PP7/zBmusqg9fZoTcezxjxbXLDVP6vWlqqHNMzDWRm0R93UxjhkYVdfzqhi0/u1BLhTcpVR8s8jVIub75Un3Ips/JfDZq9kcpimrJuI0rOywLft9bGSxbOBdtTfU4kRnDf1y1GNt+fspopasbmIDiJGnqhicp8AAZr0WPRzfleVH72T0wUpTyGJie0EwLowCshVvTIq2KPn63c8OCNylVH5GLPBF9CsAWADMA/F8hxLei7pNxx4+LpJSIDy8BcHP1SDGUBaTVBVc5lk1rupCbvIr9gxm89vZ7uLWzGY//3k22DI6NvfnIl6WtH0ZTwyxbJIkpSZraf1drPnTy8pUpm3tIFXd1F6tumev3TFrqnbfPKXLjuEXYAPkJzjRBON3T/H17x3M9pBTcFseZ5ENCCO+zSm2caAaAtwDcCeAsgFcBfE4IMWA6v7e3V/T390c2HsYfqiA9cNeNrhabuiB6+w0tePyFQSvxlXpOOY/4UpjvWNpiq00KwPJvSxePDGdsapiF0fEr6GxpwI77VhYttpqig2QJPgDYvLbLFpkix6C/bxqn6Z7JY6uWzEf3tR/G7LqZ1gJyfk9APlfN5rVdJfvI3fovh6jaZcKDiA4LIXpNx6K25G8BMFSo9Qoi+hGAewAYRZ5JBl7pZnWXhb4gqm+P9/uI7xbfDcCy8i9fuQoIYODdS1Z/9XUzLRGWk0FnSwP+22934/7tR6yJR64pyPwuG1cussa3asl8APk0w9MRPPYxeIUzytf85DfdhxrxIjc37R4YKRGHvXgAABUiSURBVKompddZDUJUVjZb79VN1CJ/HYAzyv/PAviEegIR9QHoA4D29vaIh5NckrSg5SVoxVEp+Xj29T0LjO4FvyKhpwZQFxLl5KCGVcqUubLoBzAdPbNsYb5m6XTI4oCV1iA/MeXj+WX0iYyVX359Izau7LCl3/3inUtdFzH1RV5AVsOa7kO6i0wRL9MuGvv7Xn8T+vGodqTyTtfqJvaFVyHEVgBbgby7JubhxEbYC1pRpgQujjqZ3ghkSnDlFCJoateUr10XfImaWEv1uW/ZM2jFw8tEYJvWdFm1Vidyk+i7bTFmz7rGtgC7Zc8g+m5bjPu3H0FbYz0A4PDpi54+btPvTg3JlJa9KTum089qu/qmJrd+GUYnapE/B0A13xYW3mM0wn4kjlIA1EngseeP2yzeUsckRVDGj6v52t8e7bfqo3olMtPfa6yvw4rFzdj31nls2TNk2xRlauultzI4OJzFrZ3NlttHLtw6J/gq3knb1FBnbbJSr91JsPX7oC5yZsdyVsUp1V/v9DejrpPYw1CZWiRqkX8VQBcRLUJe3D8L4Pci7rMqCfuRWK1fGhbmpwPSXr1966aJzDQBPHJvj+VuMa0PmJ4QnNpVC4QAKLov0lefHcuBCPiLez6Oxvq6oiRoehrj7FgOW/efwOa1XUViatr0ZdqF6nYfvnB7Jx57/njhqP1B1+lvRl0nUcNQmdokUpEXQkwS0Z8C+CfkQyifEkK8EWWfaaFcH70pR3i5bUoBUsvLbVzZUZSnxWsnpgl1cTUvarJ49VwsWzjP1SJ1Sla2rrvVqgYlP2/aFKW2s3X/CTxw141WdJDTk8P0Ym1z4Wixp1FOtDd3NCkbvd4psvpN90E97lVxytSGnveHqV0i98kLIZ4D8FzU/SSZUsS1XHeLSSzCaFO6O9Tycnpbfne9mlLu5n3q9midzWuXuN4/aYHvH8zgP9+51CrUcf/2I1Y1KKc87GoUzPqe63yNW921q16Hzo9fPVNIRTyJV05eNObP0TFNhF5PeaYFWC/XGVM7xL7wWguUIq7l+uhNwhBGm37Ky6l96wLklnJXjk2P1lF3gOoFQ2R/g+ffx4GhLGbNGLQyLcr0BzI7o+m+6JE2Xvl2TON2qqb0xrlLhZ8ID9x1o6/NTaUQ5vpLkqK8mJAQQiTm3/Lly0UayY59IJ54cUhkxz6IdQyP7jomHt11vKLjeOLFIXH9l3aKJ14cssbxxItDYuj8+77vSXbsA/H5p16xtaMzdP598fmnXrHadTtXb/vRXcfFo7uOFY1F/b2VMm51TKY2w6KUNk2f0e9zEv5uGX8A6BcOuhq7sKv/0irySUAKn1/xMxH0S+8moH7aM4ms/rPb57wE2UvEg0wWfnFrs5KiahqHfO/zT70S2fUz0eAm8uyuqRGcFuOCPJ57JSZTk3LtHhixiog8cNeNAGAtqMpFUC83g74hSS54ynbd3DemPDVO1/PyiSyWLZxntalmjZzITSE7lsNjzx93zbeu3gOZYkFNW+C0JuDn/noRdNOU0zj0TXC80zUdsMhXMfqX1+3L7rQYF0RYnL70ur9avm5eu8RaBDWF9XmLiD08U/bTt3qRlcr3/u1HioTeLaWxflyOddnCuUV+c9mmLEYS1GcvF45l6mB9967f++uF1+/QdNzPAi/vdE0HLPJVhNMiJjCdyzzKBV5TLnn1s2omSfkq87+YniS8REQPz5yONZ/CcGa8SOgB2Cx9p/ug3ie1HJ68HvtCaj408tbOZtd7pN4DmUVTXzh2Go+kFFF12oxlGhtb5DWKkx8njn/sk3fHaRFT+nBL8ZkH9QEH8dOaznVa8PPznn5MLmzKPmR/v7f1564LzEGu22scbmsOYfdngv3mjBDsk08NukVW7uN1mJa/TBamphlWzzXlY9drxHq9Z7rOfEjnSaugt3SX3HZDi2N6YT/3ySnXTHEGTu/wyyC/l6C/E7bSGS9Y5KsIv2LhdzHVz6YlvQ2nMUyXwZvO9qieq+Zjl356t3H4Fa/82MgSW9X9IillMnP6THEGzukkZGEQVLTD/ptg0geLfArxK2puAmHK0+K2wCuzPaobj1T0yA0/45jI5Ss0uUW15JnOxW5aN1D95WqOG69dtOqr03WEUQRbv59RLHZyxsrahUU+JagRIxO5KfStXoSJ3FTJpeBMeVoA5wXezpY5lgUfBn5dIUBxbpfpjI/2gtjPHjlnZaJUQybdwjB1ohDhSggwu3VqFxb5lKCH8Mk0uW5ZCL1CLtXc6MC0JewUmuiEWk4Q8Bay0fEcsmM5fGJRI35tYaMvd5KajVKOTy+IrVaNWrZwblEq4SD3J0wqIcAcDlm7sMhXEX6Ka+thjE4Cue3QKRw+PYoDQ9miOG4d6fJQhVp1iQDurg+ncoJO1yOzQQLAmhtbfVWnMi3q5qswzbDdE2C6apQ6btN4KuXiYAFmooRFvorwG3HS2FvnKbrSFZJPlUueYqYLtbobVdZXdfq87sc2Rdqok4afBU2TO2nz2i5sXrvEclNJGuvVDJeDtiyQTouq/jZrMUzyYZGvIvyKjpcFuq67FT/8xds4lZ3A8uubjDnh3fqWES15yDgu1SrW2XbolFVqT90Rq1ZO8qpoZHInqe3U180AAKtNQODylauOm4ZMm4qcNn8xTDXBIp9wgkReeG3nl+weGMGp7ATuWNpiiamXu0A/R50YTJ+354WZa3OjXM5N5k8SwhLRaT+6e+Ukr7Hd3NGEzpYG3NzRhI75DVabsn9pxZtcM7qVr16H3/EE9eNzaCMTNSzyFSbolzqIyPg91y2cMSw29LYpeWHm2WLjZ9fNtF71MQ9nxnD07KWSyxY+/sIghjPjePyFQXzvD2+xfPN5pp86ZL/ZsRwGz7+PTWuK4/fldaivXgSdFPTzWfSZsIlM5InoawD+E4BM4a2viHyVqJomqAj4rdXqJ4eJxK2oh1+8rsNUYERa0ADQt3oxAGGrsPTkvmFM5KasTVUP3d1ty+LoZ6ymeP18Ue1F2NF/xnpP3qP9gxkcGMoCgDEENOiiaNBJQT+f49mZsInakn9MCPE/I+6jqggqAqZarRLTFnu30nImShUVP9fR1FCH+roZ+ObPjuHo2UvWbtQtewbxiUVNeOXkKC7nruIrv/2r1o7YP1jRjqaGWZbQy6yW8rNuYx0dz2H3wIgx7l3/rBTvdd2tVjqGMAg6Kejn82IvEzbsrqkA5exodPvS62GETueV2r4bfvO/TOQmsWrJfCseXfbzwrF8dsqBdy/Z+t8/eAGj41fQ2dJgWeXys15PNW6TgNN1hr2Jq1w4nJIJHafMZeX+A/A1AKcAHAXwFIBGh/P6APQD6G9vb484V1s8eGUKLLUiUNyl5LwqMMnrfnTXsaLjptJ48v3/8J2fi2/sHCiqAuV0H0spzef32rj0HVMNIKoslES0G8BHDYceBPBtAF9HPrHI1wE8AuCPDJPMVgBbAaC3t1fox9OAl7VcqsskLKtPd/v4zVmj77LVN1W5LfA6WdCdLXOwuqsF3/zZMTTPyV+fKVRSHXuQ3bR+70fYbTJMXJQl8kKIdX7OI6LvANhZTl/Vgmlx0EuMo/LD+l1ULc6sOFlIDnbKlt9Fn4yk+2TTmi6sWNzsuziG17ic7odTqKZpN205RNEmw8RFlNE11woh3i3899MAXo+qryRRilUed+ZB3eKWO1g3r11iy++ii69cFF62cC7q62Zifc8C26Yq1dJXF0P9ROb4zfeu7gkIK+SwEiGmDFMpolx4/Ssi6kHeXXMKwBci7Ksi+LGM47bKwxiLKazPTyIwAEWRItKVoyYB09v3c21+8sqEFWPOi59MmohM5IUQvx9V23HhxzKO2yoPMhY9h4zqU1c/57SbVR7Lb2B6ryjqRcbK6+kN9HHp7iIZW6/mkfeTV4ZjzBmmGA6hDECcMcxR9K0m9nrgrht9+9Sd3DZ6LH8pFarU5GkArHzwprwyfsMkGaaWYZEPQByP8U61RoN+1muBU0bPeCUqA/xv4DFt2TdZ6XqyMZlQTGbHlC4fr41ept8Ppwlgah0W+YRTjgsi6AJnqZOY0+dMvn2vak9qST09H/yG3rbAos0uHKbWYZGPmVLDCf0Qt/vCZPEHKXytfl66gWT6A8Au2k73Me57wDBxwyIfMn4tzelFzynLuvVrafrto1TLPCoXRxiFr/26htQ+2YJnahkW+ZDx6x6Q561a0uyaOdLUXtQuCL2AR5L82X5dQwzD5GGRDxm/YqPGkK/uagkUd1+uoJksdfU92a5TAY8kLmayxc4wZljkA+BH3PyKjVMMuZ/+yhU0t6cDNVYegO1nt887kcQJgWFqCRb5AITpJvEjfmH0Z+rH9CQg89Bczl3Flj3HipKNqQR5kuDoFoaJFxb5AITp9/Ujfk5ZF8sNITQ9CUznoZnnuTEqyJNE1L7yJDwpJGEMDOMEi3wAwvT7+hE/twLZgD/LOMgagXwNsjHKC6d7FpYwJuFJIQljYBgnKJ9vPhn09vaK/v7+uIeRaNJiNcp49wfuurEsYUzC/UjCGJjahogOCyF6jcdY5JNJkoWjlKyRpbQRxbgYJo24ifw1lR4M4w/pAtjRfyayPmTK4NHxXKDPOY1Nbc9r/NKNE6YYV+KeMUy1wT75hFKJzT2l+pL97DqNY3MSb4himGJS465J4qN6GGNyytyYlPFF2R7DMP6oCXdNEh/VwxiTzNy4Zc+QsZ1SXS6A2WUSdnsMw8RLWe4aItoA4GsAfhXALUKIfuXYAwD+GMAUgD8TQvxTOX15UYlH9aCWahhj8srcGHb4HocDMky6KNcn/zqAzwB4Un2TiLoBfBbAxwAsALCbiG4QQkyV2Z8jlchdElQA9TGpxad3D4z4miy8MjeGPbmp7bH7hWGqn7JEXgjxJgAQkX7oHgA/EkJ8AOAkEQ0BuAXAz8vpL27KFUA5ScjEZID7ZBFmrhy/qO099vxxq/brF+9cGlofDMNUjqiia64D8LLy/7OF94ogoj4AfQDQ3t4e0XDCQRVAp+IVbshJQq105Eb8rhPSXhmGqTY8RZ6IdgP4qOHQg0KIZ8odgBBiK4CtQD66ptz2KkUpbhJTpSM3ZNKwdd2tpQ2yTDau7AgltQHDMPHhKfJCiHUltHsOgKoMCwvvpYZKrAHIpGErFo/4mhTCJug16u4l9ukzTPxEFUL5LIDPEtGHiGgRgC4Av4ior9SyobcND9x1Y2iWdDnhkX7QQ0aTGNbKMLVGuSGUnwbwOIAWAP9AREeEEL8phHiDiLYDGAAwCeBPooysiZI4rdGwnxai9vHrLizegcow8ZOaHa9REVa2xCTA7hOGSSduO15Tm7smrCyIabJGuQ4qw9QeqUlroOPlD/brL+at+gzDVDOpteS9LPA0Weh+YFcNw9QmqbXkvSzwuCz0qCNcnOBIF4apTVJrySeVuHax1tqTC8MweVjkK0xcYsuLrgxTm6TWXZNUynETxeXqYRimemGRryJK9avz5MAwtQu7a6qIoK4eGVEzkZvClj2DALgQCMPUGmzJVxFBXT3Ti7wi1Bw4DMNUD2zJ+6Qa48xVy79axswwTLiwJe+TcuPM4/CLNzXUYUNvG3b0n2F/PMPUKGzJ+6Tc0Me44uPjry7FMEycsMj7pNw487ji43kTFMPUNpxqmGEYpspxSzXMPnmGYZgUU5bIE9EGInqDiK4SUa/yfgcRXSaiI4V/T5Q/1HTBG5QYhqkE5frkXwfwGQBPGo4NCyF6ymw/tfCCKMMwlaAskRdCvAkARBTOaKqIcuPmvRZEqzEun2GY5BGlT34REf0zEe0jotUR9hMKQd0n5cbNe+1e5fzvDMOEgaclT0S7AXzUcOhBIcQzDh97F0C7ECJLRMsB/D8i+pgQ4t8M7fcB6AOA9vZ2/yMPmaDuk6hDEzn0kWGYMAglhJKIXgTwX4QQxvhHr+OSOEMo2T3CMEy1UvEQSiJqIaIZhZ8XA+gCcCKKvsKC87wzDJNGyg2h/DQRnQXwGwD+gYj+qXDoNgBHiegIgL8DcJ8QYrS8oSYX9p8zDJNUyo2u+SmAnxre/wmAn5TTdjXB/nOGYZIK564JAa6fyjBMUuG0BgzDMCmGRZ5hGCbFsMgzDMOkGBZ5hmGYFMMizzAMk2JY5BmGYVIMizzDMEyKSVT5PyLKABgHcCHusVSA+eDrTBO1cp1A7VxrNV3n9UKIFtOBRIk8ABBRv1OinTTB15kuauU6gdq51rRcJ7trGIZhUgyLPMMwTIpJoshvjXsAFYKvM13UynUCtXOtqbjOxPnkGYZhmPBIoiXPMAzDhASLPMMwTIpJnMgT0deJ6CgRHSGiXUS0IO4xRQUR/TURHStc70+JaF7cY4oCItpARG8Q0VUiqvqQNB0i+hQRHSeiISL6ctzjiQoieoqIzhPR63GPJSqIqI2I9hLRQOFvdnPcYyqXxIk8gL8WQiwTQvQA2Angv8c9oAh5HsDHhRDLALwF4IGYxxMVrwP4DICX4h5I2BRqGf8fAHcB6AbwOSLqjndUkfF9AJ+KexARMwngfiFEN4AVAP6k2n+fiRN5IcS/Kf9tAJDalWEhxC4hxGThvy8DWBjneKJCCPGmEOJ43OOIiFsADAkhTgghcgB+BOCemMcUCUKIlwCktlYzAAgh3hVCvFb4+X0AbwK4Lt5RlUciy/8R0TcA/AGASwDuiHk4leKPAPw47kEwgbkOgFrB/SyAT8Q0FiZEiKgDwK8DeCXekZRHLCJPRLsBfNRw6EEhxDNCiAcBPEhEDwD4UwBfregAQ8TrWgvnPIj8Y+IPKjm2MPFznQxTLRDRHAA/AfDnmneh6ohF5IUQ63ye+gMAz6GKRd7rWono8wDuBrBWVPGmhQC/07RxDkCb8v+FhfeYKoWIZiEv8D8QQvx93OMpl8T55ImoS/nvPQCOxTWWqCGiTwH4rwDWCyEm4h4PUxKvAugiokVEVAfgswCejXlMTIkQEQH4LoA3hRCPxj2eMEjcjlci+gmApQCuAjgN4D4hRCotIyIaAvAhANnCWy8LIe6LcUiRQESfBvA4gBYA7wE4IoT4zXhHFR5E9FsA/heAGQCeEkJ8I+YhRQIR/RDAJ5FPwTsC4KtCiO/GOqiQIaJVAPYD+BfkNQgAviKEeC6+UZVH4kSeYRiGCY/EuWsYhmGY8GCRZxiGSTEs8gzDMCmGRZ5hGCbFsMgzDMOkGBZ5hmGYFMMizzAMk2L+P3eXA20GsgiXAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "true_w = nd.array([3.4, 5])\n", + "true_b = 2.5\n", + "X, y = generate_data(true_w, true_b, 1000)\n", + "plotter(X[:, 1].asnumpy(), y.asnumpy())" + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[[-0.0042419 ]\n", + " [ 0.00958158]]\n", + "\n" + ] + } + ], + "source": [ + "w = nd.random.normal(scale=0.01, shape=(X.shape[1], 1))\n", + "b = nd.zeros(1)\n", + "print(w)\n", + "w.attach_grad()\n", + "b.attach_grad()\n", + "epochs = 10\n", + "batch_size = 10\n", + "lr = 0.1" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 Loss: 4.877576066064648e-05\n", + "Epoch 1 Loss: 4.792332401848398e-05\n", + "Epoch 2 Loss: 5.0487014959799126e-05\n", + "Epoch 3 Loss: 4.844959403271787e-05\n", + "Epoch 4 Loss: 4.912050280836411e-05\n", + "Epoch 5 Loss: 4.8942867579171434e-05\n", + "Epoch 6 Loss: 4.864722359343432e-05\n", + "Epoch 7 Loss: 4.856755185755901e-05\n", + "Epoch 8 Loss: 4.9007012421498075e-05\n", + "Epoch 9 Loss: 4.856191299040802e-05\n" + ] + } + ], + "source": [ + "for epoch in range(epochs):\n", + " for features, labels in data_iter(X, y, batch_size):\n", + " with autograd.record():\n", + " y_hat = model(w, b, features)\n", + " l = loss(y_hat, labels)\n", + " l.backward()\n", + " w[:] -= (lr / batch_size) * w.grad\n", + " b[:] -= (lr / batch_size) * b.grad\n", + " epoch_loss = loss(model(w, b, X), y)\n", + " print(\"Epoch {} Loss: {}\".format(epoch, epoch_loss.mean().asscalar()))" + ] + }, + { + "cell_type": "code", + "execution_count": 147, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error in w is: [[ 9.4437599e-04 1.6009443e+00]\n", + " [-1.5994844e+00 5.1546097e-04]]\n" + ] + } + ], + "source": [ + "print(\"Error in w is:\", (true_w-w).asnumpy())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}