Skip to content

Commit

Permalink
added MultiOneVsRest classifier & testing suite
Browse files Browse the repository at this point in the history
  • Loading branch information
hugobowne authored and maniteja123 committed Mar 13, 2016
1 parent 25c931a commit ce11cbc
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 0 deletions.
1 change: 1 addition & 0 deletions sklearn/MultiOneVsRest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""
MultiOneVsRestClassifier
========================

This module provides a meta-estimator that extends base estimators to
multi-target estimators. Most sklearn estimators use a response matrix to
train a target function with a single output variable. I.e. typical
estimators use the training set X to estimate a target function f(X) that
predicts a single Y. The purpose of this class is to extend estimators
to be able to estimate a series of target functions (f1, f2, f3, ..., fn)
that are trained on a single X predictor matrix to predict a series
of responses (y1, y2, y3, ..., yn).
"""

#Author: Hugo Bowne-Anderson <[email protected]>
#Author: Chris Rivera <[email protected]>
#Author: Michael Williamson
#License: BSD 3 clause

import array
import warnings

import numpy as np
import scipy.sparse as sp

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.base import clone, is_classifier
from sklearn.base import MetaEstimatorMixin, is_regressor
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.utils import check_random_state
from sklearn.utils.validation import _num_samples
from sklearn.utils.validation import check_consistent_length
from sklearn.utils.validation import check_is_fitted
from sklearn.externals.joblib import Parallel
from sklearn.externals.joblib import delayed
from sklearn.multiclass import OneVsRestClassifier


class MultiOneVsRestClassifier(object):
    """Convert any classifier estimator into a multi-target classifier.

    This class fits and predicts a series of one-versus-all models to a
    response matrix Y, which has n_samples and p_target variables, on the
    predictor matrix X with n_samples and m_feature variables. This allows
    for multiple target variable classifications. For each target variable
    (column in Y), a separate OneVsRestClassifier is fit. See the base
    OneVsRestClassifier class in sklearn.multiclass for more details.

    Parameters
    ----------
    estimator : estimator object
        An estimator object implementing `fit` & `predict_proba`.

    n_jobs : int, optional, default: 1
        The number of jobs to use for the computation. If -1 all CPUs are
        used. If 1 is given, no parallel computing code is used at all,
        which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but
        one are used. Note that parallel processing only occurs if there
        are multiple classes within each target variable; the target
        variables (columns of y) themselves are processed in series.

    Attributes
    ----------
    estimator : sklearn estimator
        The base estimator used to construct the models.

    estimators_ : dict
        Set by `fit`. Maps each target column index to the
        OneVsRestClassifier fitted on that column.
    """

    def __init__(self, estimator=None, n_jobs=1):
        self.estimator = estimator
        self.n_jobs = n_jobs

    def fit(self, X, y):
        """Fit the model to data.

        Creates a separate model for each response column.

        Parameters
        ----------
        X : (sparse) array-like, shape = [n_samples, n_features]
            Data.

        y : (sparse) array-like, shape = [n_samples, p_targets]
            Multi-class targets, one column per target variable.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If y is not two-dimensional, or if X and y do not have the
            same number of rows.
        """
        # Dense containers (lists, tuples, ...) are coerced so that the
        # column slicing below works; sparse matrices are left untouched
        # because they already support `y[:, i]`.
        if not sp.issparse(y):
            y = np.asarray(y)
        if y.ndim != 2:
            raise ValueError(
                "y must be a 2-dimensional array of shape "
                "[n_samples, p_targets]; got %d dimension(s)." % y.ndim)

        # X and y must describe the same samples.
        check_consistent_length(X, y)

        # One classifier is trained per target column.
        self._num_y = y.shape[1]

        # Fitted classifiers, keyed by target column index.
        self.estimators_ = {}
        for i in range(self._num_y):
            # A fresh clone keeps the models independent of each other and
            # of the user-supplied estimator. `n_jobs` is passed by keyword
            # because later scikit-learn releases make it keyword-only.
            ovr = OneVsRestClassifier(clone(self.estimator),
                                      n_jobs=self.n_jobs)
            self.estimators_[i] = ovr.fit(X, y[:, i])

        return self

    def predict(self, X):
        """Predict every target variable using its own trained model.

        Parameters
        ----------
        X : (sparse) array-like, shape = [n_samples, n_features]
            Data.

        Returns
        -------
        y : dict
            Maps each target column index to the output of the
            corresponding OneVsRestClassifier's `predict(X)`.
            Note: entirely separate models are used for each target.
        """
        # check to see if the fit has been performed
        check_is_fitted(self, 'estimators_')

        # `items()` instead of `iteritems()`: the latter does not exist on
        # Python 3 dictionaries.
        return dict((label, model_.predict(X))
                    for label, model_ in self.estimators_.items())

    def predict_proba(self, X):
        """Probability estimates.

        This returns prediction probabilities for each class for each
        target in the form of a dictionary.

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]

        Returns
        -------
        prob_dict : dict
            Maps each target column index to the output of the
            corresponding OneVsRestClassifier's `predict_proba(X)`
            (expected shape [n_samples, n_classes]).
        """
        # check to see whether the fit has occurred.
        check_is_fitted(self, 'estimators_')

        return dict((label, model_.predict_proba(X))
                    for label, model_ in self.estimators_.items())

    def score(self, X, Y):
        """Return the mean accuracy on the given test data and labels.

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]

        Y : (sparse) array-like, shape = [n_samples, p_targets]

        Returns
        -------
        scores : list
            One entry per target column, in column order: the value
            returned by that column's fitted classifier's `score`.
        """
        check_is_fitted(self, 'estimators_')

        # Same coercion as in `fit`, so list-like Y can be column-sliced.
        if not sp.issparse(Y):
            Y = np.asarray(Y)

        # Score each target column with the model trained for it.
        return [self.estimators_[i].score(X, Y[:, i])
                for i in range(self._num_y)]

    def get_params(self, deep=True):
        """Return the parameters of the wrapped base estimator.

        Parameters
        ----------
        deep : bool, optional, default: True
            Forwarded to the base estimator's `get_params`.
        """
        return self.estimator.get_params(deep=deep)

    def set_params(self, params=None, **kwargs):
        """Set parameters on the wrapped base estimator.

        Parameters
        ----------
        params : dict, optional
            Parameter names mapped to their new values.

        **kwargs
            Additional parameters given as keywords; these take precedence
            over entries of `params` with the same name.

        Returns
        -------
        self
        """
        merged = dict(params or {})
        merged.update(kwargs)
        # The estimator's `set_params` only accepts keyword arguments, so
        # the dictionary has to be unpacked rather than passed positionally.
        self.estimator.set_params(**merged)
        return self

    def __repr__(self):
        return 'MultiOneVsRestClassifier( %s )' % repr(self.estimator)

    @property
    def multilabel_(self):
        """List of (target index, multilabel flag) tuples, one per model."""
        return [(label, model_.multilabel_)
                for label, model_ in self.estimators_.items()]

    @property
    def classes_(self):
        """List of (target index, fitted label binarizer) tuples.

        Note that the second element is each model's `label_binarizer_`
        object itself, not an array of class labels.
        """
        return [(label, model_.label_binarizer_)
                for label, model_ in self.estimators_.items()]
Expand Down
133 changes: 133 additions & 0 deletions sklearn/tests/test_mult_one_vs_rests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import numpy as np

from MultiOneVsRest import MultiOneVsRestClassifier
from sklearn.datasets import load_digits
from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

# Import libraries for validation
import sklearn.utils.estimator_checks as ec
import sklearn.utils.validation as val
from sklearn import datasets

# import the shuffle
from sklearn.utils import shuffle
from sklearn.preprocessing import LabelBinarizer

# these are functions for testing
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_almost_equal


def _create_data_set():
    """Build a multi-target data set from the iris data set.

    Returns
    -------
    X : the iris predictor data.
    Y : a multi-target array with one row per sample and three columns:
        the original iris target followed by two seeded shuffles of it.
    """
    iris = datasets.load_iris()
    original_target = iris.target

    # The extra targets are reshuffled copies of the real one; fixed seeds
    # keep the generated data set reproducible between runs.
    target_rows = [original_target]
    for seed in (1, 2):
        target_rows.append(shuffle(original_target, random_state=seed))

    # Stack the targets as rows, then transpose so each becomes a column.
    stacked = np.vstack(target_rows).T

    return iris.data, stacked

def _set_up_multi_target_random_forest():
    """Create a seeded random forest and a multi-target wrapper around it.

    Returns the pair ``(forest, multi_target_forest)``.
    """
    base_forest = RandomForestClassifier(random_state=1, n_estimators=100)
    wrapped_forest = MultiOneVsRestClassifier(base_forest, n_jobs=-1)
    return base_forest, wrapped_forest


def test_multi_target_init_with_random_forest():
    """Check that the multi-target wrapper stores its constructor arguments
    when built around a random forest.
    """
    base_forest, wrapped_forest = _set_up_multi_target_random_forest()

    # The wrapper must expose the estimator it was given ...
    assert_equal(base_forest, wrapped_forest.estimator)
    # ... as well as the requested number of jobs.
    assert_equal(-1, wrapped_forest.n_jobs)

def test_multi_target_fit_and_predict_with_random_forest():
    """Fit the multi-target forest and compare its predictions with those of
    a plain forest trained on each target column separately.
    """
    X, Y = _create_data_set()
    forest, multi_target_forest = _set_up_multi_target_random_forest()

    # Train the wrapper on all targets, then predict on the training data.
    multi_target_forest.fit(X, Y)
    predictions = multi_target_forest.predict(X)

    # One set of predictions is expected per target column.
    assert_equal(3, len(predictions))

    # Refit the plain forest on each column in turn; its predictions must
    # match the wrapper's predictions for that column.
    for column in range(3):
        target = Y[:, column]
        forest.fit(X, target)
        expected = list(forest.predict(X))
        assert_equal(expected, list(predictions[column]))


def test_multi_target_fit_and_predict_probs_with_random_forest():
    """Test that the predicted probabilities are as expected, up to one
    decimal place.
    """
    # create the data set using the helper function
    X, Y = _create_data_set()
    forest, multi_target_forest = _set_up_multi_target_random_forest()

    # train the multi_target_forest
    multi_target_forest.fit(X, Y)

    # The probabilities for all targets are computed once up front instead
    # of being recomputed for every target inside the loop.
    multi_probabilities = multi_target_forest.predict_proba(X)

    # train a forest on each column and then assert that its probabilities
    # agree with those of the multi-target model
    for i in range(3):
        forest_ = clone(forest)  # unfitted copy with the same parameters
        forest_.fit(X, Y[:, i])
        assert_almost_equal(list(forest_.predict_proba(X)),
                            list(multi_probabilities[i]),
                            decimal=1)


def test_multi_target_score():
    """Compare the wrapper's per-target scores with those of a plain forest
    fitted and scored on each target column.
    """
    X, Y = _create_data_set()
    forest, multi_target_forest = _set_up_multi_target_random_forest()

    # Fit on all targets, then score: one value per target is expected.
    multi_target_forest.fit(X, Y)
    multi_score = multi_target_forest.score(X, Y)

    # Fit and score the plain forest column by column; each score should
    # agree with the wrapper's score for the same column.
    for column in range(3):
        target = Y[:, column]
        forest.fit(X, target)
        single_score = forest.score(X, target)
        assert_almost_equal(single_score, multi_score[column])


# Run the whole suite when this file is executed directly as a script.
# The guard has to compare against '__main__' exactly: the previous value,
# '__main__()', could never match, so none of the tests were ever run.
if __name__ == '__main__':
    test_multi_target_init_with_random_forest()
    test_multi_target_fit_and_predict_with_random_forest()
    test_multi_target_fit_and_predict_probs_with_random_forest()
    test_multi_target_score()

0 comments on commit ce11cbc

Please sign in to comment.