Skip to content
This repository has been archived by the owner on Feb 23, 2021. It is now read-only.

Commit

Permalink
HULK SMASH EXAMPLE WITH AUTOPEP8, FOR GOOD MEASURE
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmaney committed Jan 14, 2014
1 parent 2737fb0 commit d7a2d57
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions examples/three_clusters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pandas import DataFrame,Series
from pandas import DataFrame, Series
import pandas as pd
import numpy as np
import sys
Expand All @@ -8,24 +8,29 @@

from k_means_plus_plus import *

np.random.seed(1234) #For reproducibility
np.random.seed(1234) # For reproducibility

# We create a data set with three sets of 500 points each chosen from a normal distribution with a standard deviation of 10.
# The means for the distributions from which we sample are (25,45), (-30,5), and (5,-20)
data = DataFrame({'x':10*np.random.randn(500) + 25,'y':10*np.random.randn(500) + 45},columns=list('xy'))
data = data.append(DataFrame({'x':10*np.random.randn(500) - 30,'y':10*np.random.randn(500) + 5},columns=list('xy')))
data = data.append(DataFrame({'x':10*np.random.randn(500) + 5,'y':10*np.random.randn(500) - 20},columns=list('xy')))
# The means for the distributions from which we sample are:
# (25,45), (-30,5), and (5,-20)
data = DataFrame({'x': 10 * np.random.randn(500) + 25, 'y':
10 * np.random.randn(500) + 45}, columns=list('xy'))
data = data.append(DataFrame(
{'x': 10 * np.random.randn(500) - 30, 'y': 10 * np.random.randn(500) + 5}, columns=list('xy')))
data = data.append(DataFrame(
{'x': 10 * np.random.randn(500) + 5, 'y': 10 * np.random.randn(500) - 20}, columns=list('xy')))

# Grab a scatterplot
import matplotlib.pyplot as plt
plt.scatter(data['x'],data['y'],s=5)
plt.scatter(data['x'], data['y'], s=5)
plt.savefig("three_clusters_scatterplot.png")

# Cluster
kmpp = KMeansPlusPlus(data,3)
kmpp = KMeansPlusPlus(data, 3)
kmpp.cluster()

# Get a scatterplot that's color-coded by cluster
colors = ["red" if x == 0 else "blue" if x == 1 else "green" for x in kmpp.clusters]
plt.scatter(data['x'],data['y'],s=5,c=colors)
plt.savefig("three_clusters_clusters.png")
colors = [
"red" if x == 0 else "blue" if x == 1 else "green" for x in kmpp.clusters]
plt.scatter(data['x'], data['y'], s=5, c=colors)
plt.savefig("three_clusters_clusters.png")

0 comments on commit d7a2d57

Please sign in to comment.