Skip to content

Commit

Permalink
initial version of k-nearest neighbours
Browse files Browse the repository at this point in the history
  • Loading branch information
jovsa committed Jan 18, 2021
1 parent 232b0d5 commit 9936e0b
Showing 1 changed file with 153 additions and 18 deletions.
171 changes: 153 additions & 18 deletions notebooks/custering_and_neighbours.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
"from sklearn.datasets import make_blobs\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.cluster import KMeans"
"from sklearn.cluster import KMeans\n",
"from collections import defaultdict"
]
},
{
Expand Down Expand Up @@ -238,60 +239,194 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X: (4, 2)\n",
"centroids:\n",
" [[ 3. 2. ]\n",
" [ 3. 2. ]\n",
" [ 0. 2. ]\n",
" [-3. 2.5]\n",
" [ 0. 2. ]]\n",
"point: [0. 2.] centroid: [3. 2.]\n",
"point: [0. 2.] centroid: [3. 2.]\n",
"point: [0. 2.] centroid: [0. 2.]\n",
"point: [0. 2.] centroid: [-3. 2.5]\n",
"point: [0. 2.] centroid: [0. 2.]\n",
"point: [3. 2.] centroid: [3. 2.]\n",
"point: [3. 2.] centroid: [3. 2.]\n",
"point: [3. 2.] centroid: [0. 2.]\n",
"point: [3. 2.] centroid: [-3. 2.5]\n",
"point: [3. 2.] centroid: [0. 2.]\n",
"point: [-3. 3.] centroid: [3. 2.]\n",
"point: [-3. 3.] centroid: [3. 2.]\n",
"point: [-3. 3.] centroid: [0. 2.]\n",
"point: [-3. 3.] centroid: [-3. 2.5]\n",
"point: [-3. 3.] centroid: [0. 2.]\n",
"point: [-3. 2.5] centroid: [3. 2.]\n",
"point: [-3. 2.5] centroid: [3. 2.]\n",
"point: [-3. 2.5] centroid: [0. 2.]\n",
"point: [-3. 2.5] centroid: [-3. 2.5]\n",
"point: [-3. 2.5] centroid: [0. 2.]\n",
"3 v: [array([0., 2.]), array([3., 2.])] [1.5 2. ]\n",
"0 v: [array([-3., 3.]), array([-3. , 2.5])] [-3. 2.75]\n",
"point: [0. 2.] centroid: [-3. 2.75]\n"
]
},
{
"ename": "IndexError",
"evalue": "index 0 is out of bounds for axis 1 with size 0",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-413-951d7dbebf1b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"X: {X_new.shape}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mkmeans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mKMeans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_new\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mkmeans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-413-951d7dbebf1b>\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, k, initilization)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcentroid\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcentroids\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"point: {point} centroid: {centroid}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0mdists\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc_idx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meuclid_dist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpoint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcentroid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;31m# print(dists)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mIndexError\u001b[0m: index 0 is out of bounds for axis 1 with size 0"
]
}
],
"source": [
class KMeans:
    """Minimal k-means clustering (Lloyd's algorithm) for a 2-D NumPy array.

    NOTE(review): this class shadows ``sklearn.cluster.KMeans`` imported in the
    notebook's import cell. It is deliberately not renamed here so that
    existing cells calling ``KMeans(X)`` keep working.

    Attributes:
        x: float array of shape (n_samples, n_features) holding the data.
        seed: seed of the private random generator used for initialisation.
        centroids: array of shape (k, n_features) once ``fit`` has run,
            ``None`` before.
        labels: index of the closest centroid for every row of ``x`` once
            ``fit`` has run, ``None`` before.
    """

    def __init__(self, x, seed=10):
        """Store the data set and create a private, seeded random generator.

        Args:
            x: array-like of shape (n_samples, n_features). A 1-D input is
                treated as n_samples observations of a single feature.
            seed: seed for the private generator (default 10, the value the
                original code passed to ``np.random.seed``).
        """
        # Cast to float so that centroid means are never truncated when the
        # caller hands in integer data.
        data = np.asarray(x, dtype=float)
        if data.ndim == 1:
            data = data.reshape(-1, 1)
        self.x = data
        # Keep the seed value itself (``np.random.seed`` returns None) and use
        # a private generator so constructing a model does not reseed the
        # global NumPy RNG for the rest of the notebook.
        self.seed = seed
        self._rng = np.random.default_rng(seed)
        self.centroids = None
        self.labels = None

    def init_centroids(self, k, initilization='random'):
        """Pick ``k`` rows of the data set as the initial centroids.

        Args:
            k: number of centroids to draw.
            initilization: accepted for interface compatibility; only random
                row sampling is implemented and the value is not inspected
                here (``fit`` rejects anything other than ``'random'``).

        Returns:
            A new float array of shape (k, n_features). Rows are distinct
            samples whenever ``k <= n_samples``; if more centroids than
            samples are requested, rows are drawn with replacement, so
            duplicates are unavoidable.
        """
        n_samples = self.x.shape[0]
        picks = self._rng.choice(n_samples, size=k, replace=k > n_samples)
        # Fancy indexing copies, so later centroid updates never touch self.x.
        return self.x[picks]

    def euclid_dist(self, x, y):
        """Return the Euclidean (L2) norm of ``x - y``."""
        return np.linalg.norm(x - y)

    @staticmethod
    def _pairwise_dists(points, centroids):
        """Return the (n_points, n_centroids) matrix of Euclidean distances."""
        return np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)

    def fit(self, k, initilization='random', max_iter=100, tol=1e-8, verbose=False):
        """Cluster the stored data into ``k`` groups.

        Alternates between assigning every point to its closest centroid and
        moving every centroid to the mean of its assigned points, until the
        centroids move by at most ``tol`` or ``max_iter`` passes have run.
        The result is stored on ``self.centroids`` and ``self.labels``.

        Args:
            k: number of clusters (positive integer).
            initilization: initialisation strategy; only ``'random'`` exists.
            max_iter: upper bound on the number of update passes.
            tol: stop once the Frobenius norm of the centroid shift is at or
                below this value.
            verbose: when True, print the centroids after every pass.

        Raises:
            NotImplementedError: if ``initilization`` is not ``'random'``.
            ValueError: if ``k`` is smaller than 1 or the data set is empty.
        """
        if initilization != 'random':
            raise NotImplementedError
        if k < 1:
            raise ValueError("k must be a positive integer")
        if self.x.shape[0] == 0:
            raise ValueError("cannot fit k-means on an empty data set")

        # 0: choose the starting centroids
        centroids = self.init_centroids(k, initilization)

        for iteration in range(max_iter):
            # 1: assign each point to its *closest* centroid (argmin, not argmax)
            nearest = np.argmin(self._pairwise_dists(self.x, centroids), axis=1)

            # 2: move each centroid to the mean of its members. The loop
            #    variable must not be called ``k``: rebinding the cluster count
            #    is what broke the distance matrix on the second pass before.
            updated = centroids.copy()
            for c_idx in range(centroids.shape[0]):
                members = self.x[nearest == c_idx]
                if len(members):
                    updated[c_idx] = members.mean(axis=0)
                # A centroid with no members keeps its previous position.

            # 3: decide whether to stop
            shift = np.linalg.norm(updated - centroids)
            centroids = updated
            if verbose:
                print(f"iteration {iteration}: centroids=\n{centroids}")
            if shift <= tol:
                break

        self.centroids = centroids
        # Recompute the assignment against the final centroids so labels and
        # centroids are always consistent with each other.
        self.labels = np.argmin(self._pairwise_dists(self.x, centroids), axis=1)

    def predict(self, y):
        """Return the index of the closest fitted centroid for each row of ``y``.

        Args:
            y: array-like of shape (n_points, n_features), or a single 1-D
                point (for single-feature models a 1-D input is read as
                several one-feature points).

        Returns:
            Integer array of shape (n_points,) with centroid indices.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
        """
        if self.centroids is None:
            raise RuntimeError("call fit() before predict()")
        points = np.asarray(y, dtype=float)
        if points.ndim == 1:
            if self.centroids.shape[1] == 1:
                points = points.reshape(-1, 1)
            else:
                points = points.reshape(1, -1)
        return np.argmin(self._pairwise_dists(points, self.centroids), axis=1)
"\n",
"print(f\"X: {X_new.shape}\")\n",
"kmeans = KMeans(X_new)\n",
"kmeans.fit(k=5)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"* [Hands on ML](https://github.com/ageron/handson-ml2)\n",
"* [Sklearn Nearest Neighbors (supervised)](https://scikit-learn.org/stable/modules/neighbors.html)\n",
"* [Sklearn Clustering (unsupervised)](https://scikit-learn.org/stable/modules/clustering.html#clustering)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([[0, 1, 2],\n",
" [3, 4, 5]])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kmeans = KMeans(X)\n",
"kmeans.fit(k=5)"
"np_array_2d = np.arange(0, 6).reshape([2,3])\n",
"np_array_2d"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([3, 5, 7])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"## References\n",
"* [Hands on ML](https://github.com/ageron/handson-ml2)\n",
"* [Sklearn Nearest Neighbors (supervised)](https://scikit-learn.org/stable/modules/neighbors.html)\n",
"* [Sklearn Clustering (unsupervised)](https://scikit-learn.org/stable/modules/clustering.html#clustering)"
"np.sum(np_array_2d, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0 3]\n",
"[1 4]\n",
"[2 5]\n"
]
}
],
"source": [
"print(np_array_2d[:,0])\n",
"print(np_array_2d[:,1])\n",
"print(np_array_2d[:,2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"x = np.asarray([1,2,1,1,3,4])\n"
]
},
{
"cell_type": "code",
Expand Down

0 comments on commit 9936e0b

Please sign in to comment.