Skip to content

Commit

Permalink
Merge pull request #151 from BethanyL/HW9
Browse files Browse the repository at this point in the history
Hw9
  • Loading branch information
jakevdp committed Dec 8, 2014
2 parents d9652e5 + 2141712 commit f12def6
Showing 1 changed file with 202 additions and 0 deletions.
202 changes: 202 additions & 0 deletions BethanyL/HW9.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
{
"metadata": {
"name": "",
"signature": "sha256:245663b949d0c983b529355a6b9cdae018a29591b12622ececb03ebb43996880"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"from __future__ import print_function"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from IPython import parallel\n",
"clients = parallel.Client()\n",
"clients.block = True # use synchronous computations\n",
"print(clients.ids)\n",
"dview = clients.direct_view()"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"[0, 1]\n"
]
}
],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from astroML.datasets import fetch_sdss_specgals\n",
"import numpy as np\n",
"data = fetch_sdss_specgals()\n",
"\n",
"# put magnitudes in a matrix\n",
"X = np.vstack([data['modelMag_%s' % f] for f in 'ugriz']).T\n",
"y = data['z']\n",
"\n",
"# down-sample the data for speed\n",
"X = X[::10]\n",
"y = y[::10]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"%%px\n",
"import numpy as np\n",
"\n",
"import sklearn.ensemble, sklearn.cross_validation\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.cross_validation import cross_val_score, StratifiedKFold\n",
"\n",
"def rfreg(x):\n",
" # define the model with the given value of C\n",
" \n",
" [n,m]=x.shape\n",
" scores = np.zeros(n)\n",
" score = np.zeros(n)\n",
" for j in range(n):\n",
" model = RandomForestRegressor(n_estimators=x[j][0], max_depth=x[j][1])\n",
"\n",
" # compute the scores via cross-validation\n",
" print('ready to score!!!')\n",
" scores = cross_val_score(model, Xd, Yd, scoring='r2', cv=3)\n",
"\n",
" # print the mean of the cross-validation scores\n",
" score[j] = np.mean(scores)\n",
"\n",
"# print(score)\n",
" return score"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 20
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"#%px \n",
"def parallel_rfreg(dview, n_est, max_d):\n",
" dview.push(dict(Xd = X, Yd = y))\n",
" #dview.push(y)\n",
" \n",
" A, B =np.meshgrid(n_est, max_d) #product(n_est, max_d)\n",
" A = np.reshape(A,(A.size,1))\n",
" B = np.reshape(B,(B.size,1))\n",
"\n",
" params = np.hstack((A,B))\n",
" #print(params)\n",
" #A = product(n_est, max_d)\n",
" #%px import numpy as np\n",
" \n",
" dview.scatter('param',params)\n",
" #dview.scatter('param',param)\n",
" #dview.apply(rfreg,param)\n",
" dview.execute('sc=rfreg(param)')\n",
" sc = dview.gather('sc')\n",
" print(sc)\n",
" \n",
" return max(sc), params[np.argmax(sc)][:2]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 16
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from itertools import product\n",
"n_est=np.arange(5, 15)\n",
"max_d=np.arange(1,5)\n",
"print(n_est)\n",
"print(max_d)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"[ 5 6 7 8 9 10 11 12 13 14]\n",
"[1 2 3 4]\n"
]
}
],
"prompt_number": 22
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"parallel_rfreg(dview,n_est,max_d)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"[ 0.33465985 0.33217848 0.32786725 0.33632745 0.33604494 0.33595623\n",
" 0.33868149 0.33468641 0.33519086 0.33636052 0.49129664 0.49420653\n",
" 0.49188764 0.4891963 0.49128236 0.49017637 0.48764369 0.49241277\n",
" 0.49448054 0.49160429 0.55641043 0.55708224 0.55813383 0.55692452\n",
" 0.55086971 0.55469108 0.55496256 0.5517019 0.55205004 0.55468357\n",
" 0.60631688 0.61041368 0.60785728 0.6139131 0.6120328 0.61333911\n",
" 0.61093431 0.61142202 0.61327971 0.61281658]\n"
]
},
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 23,
"text": [
"(0.61391310443364611, array([8, 4]))"
]
}
],
"prompt_number": 23
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}

0 comments on commit f12def6

Please sign in to comment.