ENH: new CategoricalEncoder class (scikit-learn#9151)
jorisvandenbossche authored and jnothman committed Nov 21, 2017
1 parent bbdcd70 commit a2ebb8c
Showing 12 changed files with 744 additions and 46 deletions.
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -1198,6 +1198,7 @@ Model validation
preprocessing.MinMaxScaler
preprocessing.Normalizer
preprocessing.OneHotEncoder
preprocessing.CategoricalEncoder
preprocessing.PolynomialFeatures
preprocessing.QuantileTransformer
preprocessing.RobustScaler
112 changes: 76 additions & 36 deletions doc/modules/preprocessing.rst
@@ -455,47 +455,87 @@ Such features can be efficiently coded as integers, for instance
``[0, 1, 3]`` while ``["female", "from Asia", "uses Chrome"]`` would be
``[1, 2, 1]``.

Such an integer representation cannot be used directly with scikit-learn estimators, as these
expect continuous input and would interpret the categories as being ordered, which is often
not desired (e.g. the set of browsers was ordered arbitrarily).

One possibility to convert categorical features to features that can be used
with scikit-learn estimators is to use a one-of-K or one-hot encoding, which is
implemented in :class:`OneHotEncoder`. This estimator transforms each
categorical feature with ``m`` possible values into ``m`` binary features, with
only one active.
To convert categorical features to such integer codes, we can use the
:class:`CategoricalEncoder`. When specifying that we want to perform an
ordinal encoding, the estimator transforms each categorical feature to one
new feature of integers (0 to n_categories - 1)::

>>> enc = preprocessing.CategoricalEncoder(encoding='ordinal')
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
encoding='ordinal', handle_unknown='error')
>>> enc.transform([['female', 'from US', 'uses Safari']])
array([[ 0., 1., 1.]])

Such an integer representation cannot, however, be used directly with all
scikit-learn estimators, as these expect continuous input and would interpret
the categories as being ordered, which is often not desired (e.g. the set of
browsers was ordered arbitrarily).

Another possibility to convert categorical features to features that can be
used with scikit-learn estimators is a one-of-K encoding, also known as one-hot
or dummy encoding. This encoding is the default behaviour of the
:class:`CategoricalEncoder`, which transforms each categorical feature with
``n_categories`` possible values into ``n_categories`` binary features, one of
them 1 and all others 0.

Continuing the example above::

>>> enc = preprocessing.OneHotEncoder()
>>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]]) # doctest: +ELLIPSIS
OneHotEncoder(categorical_features='all', dtype=<... 'numpy.float64'>,
handle_unknown='error', n_values='auto', sparse=True)
>>> enc.transform([[0, 1, 3]]).toarray()
array([[ 1., 0., 0., 1., 0., 0., 0., 0., 1.]])

By default, how many values each feature can take is inferred automatically from the dataset.
It is possible to specify this explicitly using the parameter ``n_values``.
There are two genders, three possible continents and four web browsers in our
dataset.
Then we fit the estimator, and transform a data point.
In the result, the first two numbers encode the gender, the next set of three
numbers the continent and the last four the web browser.

Note that if there is a possibility that the training data might have missing categorical
features, one has to explicitly set ``n_values``. For example,

>>> enc = preprocessing.OneHotEncoder(n_values=[2, 3, 4])
>>> # Note that there are missing categorical values for the 2nd and 3rd
>>> # features
>>> enc.fit([[1, 2, 3], [0, 2, 0]]) # doctest: +ELLIPSIS
OneHotEncoder(categorical_features='all', dtype=<... 'numpy.float64'>,
handle_unknown='error', n_values=[2, 3, 4], sparse=True)
>>> enc.transform([[1, 0, 0]]).toarray()
array([[ 0., 1., 1., 0., 0., 1., 0., 0., 0.]])
>>> enc = preprocessing.CategoricalEncoder()
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
encoding='onehot', handle_unknown='error')
>>> enc.transform([['female', 'from US', 'uses Safari'],
... ['male', 'from Europe', 'uses Safari']]).toarray()
array([[ 1., 0., 0., 1., 0., 1.],
[ 0., 1., 1., 0., 0., 1.]])

By default, the values each feature can take are inferred automatically
from the dataset and can be found in the ``categories_`` attribute::

>>> enc.categories_
[array(['female', 'male'], dtype=object), array(['from Europe', 'from US'], dtype=object), array(['uses Firefox', 'uses Safari'], dtype=object)]
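
Since the order of the one-hot columns follows the order of the per-feature
arrays in ``categories_``, labels for the encoded columns can be reconstructed
by hand. A minimal sketch (plain Python; this commit provides no dedicated
helper for it)::

    >>> feature_names = [category for feature in enc.categories_
    ...                  for category in feature]

This yields ``['female', 'male', 'from Europe', 'from US', 'uses Firefox',
'uses Safari']``, matching the six columns of the one-hot output above.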

It is possible to specify this explicitly using the parameter ``categories``.
There are two genders, four possible locations and four web browsers in our
dataset::

>>> genders = ['female', 'male']
>>> locations = ['from Africa', 'from Asia', 'from Europe', 'from US']
>>> browsers = ['uses Chrome', 'uses Firefox', 'uses IE', 'uses Safari']
>>> enc = preprocessing.CategoricalEncoder(categories=[genders, locations, browsers])
>>> # Note that there are missing categorical values for the 2nd and 3rd
>>> # features
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
CategoricalEncoder(categories=[...],
dtype=<... 'numpy.float64'>, encoding='onehot',
handle_unknown='error')
>>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
array([[ 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.]])

If there is a possibility that the training data might be missing certain
categorical values, it can often be better to specify ``handle_unknown='ignore'``
instead of setting the ``categories`` manually as above. When
``handle_unknown='ignore'`` is specified and unknown categories are encountered
during transform, no error will be raised, but the resulting one-hot encoded
columns for this feature will be all zeros
(``handle_unknown='ignore'`` is only supported for one-hot encoding)::

>>> enc = preprocessing.CategoricalEncoder(handle_unknown='ignore')
>>> X = [['male', 'from US', 'uses Safari'], ['female', 'from Europe', 'uses Firefox']]
>>> enc.fit(X) # doctest: +ELLIPSIS
CategoricalEncoder(categories='auto', dtype=<... 'numpy.float64'>,
encoding='onehot', handle_unknown='ignore')
>>> enc.transform([['female', 'from Asia', 'uses Chrome']]).toarray()
array([[ 1., 0., 0., 0., 0., 0.]])
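
In practice this matters when the encoder is fitted inside a pipeline and
previously unseen categories appear at prediction time. A minimal sketch
(editorial illustration with invented toy labels, not part of this commit)::

    >>> from sklearn.pipeline import make_pipeline
    >>> from sklearn.linear_model import LogisticRegression
    >>> X = [['male', 'from US', 'uses Safari'],
    ...      ['female', 'from Europe', 'uses Firefox']]
    >>> y = [0, 1]
    >>> clf = make_pipeline(
    ...     preprocessing.CategoricalEncoder(handle_unknown='ignore'),
    ...     LogisticRegression())
    >>> clf = clf.fit(X, y)
    >>> # 'from Asia' and 'uses Chrome' were never seen during fit; the
    >>> # encoder emits all-zero columns for them instead of raising.
    >>> pred = clf.predict([['female', 'from Asia', 'uses Chrome']])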


See :ref:`dict_feature_extraction` for categorical features that are represented
as a dict, not as integers.
as a dict, not as scalars.

.. _imputation:

2 changes: 2 additions & 0 deletions doc/whats_new/_contributors.rst
@@ -149,3 +149,5 @@
.. _Neeraj Gangwar: https://neerajgangwar.in

.. _Arthur Mensch: https://amensch.fr

.. _Joris Van den Bossche: https://github.com/jorisvandenbossche
11 changes: 11 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -45,6 +45,17 @@ Classifiers and regressors
Naive Bayes classifier described in Rennie et al. (2003).
By :user:`Michael A. Alcorn <airalcorn2>`.

Preprocessing

- Added :class:`preprocessing.CategoricalEncoder`, which allows encoding
categorical features as a numeric array, either using a one-hot (or
dummy) encoding scheme or by converting to ordinal integers.
Compared to the existing :class:`OneHotEncoder`, this new class handles
encoding of all feature types (also handles string-valued features) and
derives the categories based on the unique values in the features instead of
the maximum value in the features.
By :user:`Vighnesh Birodkar <vighneshbirodkar>` and `Joris Van den Bossche`_.
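
The practical difference is easiest to see with string-valued input, which
the existing :class:`OneHotEncoder` rejects because it requires non-negative
integer-coded features. A minimal sketch (editorial illustration, not part
of the commit)::

    >>> from sklearn.preprocessing import CategoricalEncoder
    >>> X = [['uses Chrome'], ['uses Safari'], ['uses Chrome']]
    >>> CategoricalEncoder().fit_transform(X).toarray()
    array([[ 1.,  0.],
           [ 0.,  1.],
           [ 1.,  0.]])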

Model evaluation

- Added the :func:`metrics.balanced_accuracy_score` metric and a corresponding
6 changes: 3 additions & 3 deletions examples/ensemble/plot_feature_transformation.py
@@ -34,7 +34,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier,
GradientBoostingClassifier)
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import CategoricalEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve
from sklearn.pipeline import make_pipeline
@@ -62,7 +62,7 @@

# Supervised transformation based on random forests
rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator)
rf_enc = OneHotEncoder()
rf_enc = CategoricalEncoder()
rf_lm = LogisticRegression()
rf.fit(X_train, y_train)
rf_enc.fit(rf.apply(X_train))
@@ -72,7 +72,7 @@
fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)

grd = GradientBoostingClassifier(n_estimators=n_estimator)
grd_enc = OneHotEncoder()
grd_enc = CategoricalEncoder()
grd_lm = LogisticRegression()
grd.fit(X_train, y_train)
grd_enc.fit(grd.apply(X_train)[:, :, 0])
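
The pattern at work here, condensed into a self-contained sketch (editorial
illustration with synthetic data; not part of the commit)::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import CategoricalEncoder

    X, y = make_classification(n_samples=200, random_state=0)
    rf = RandomForestClassifier(max_depth=3, n_estimators=10).fit(X, y)

    # Each column of rf.apply(X) holds the index of the leaf a sample
    # reaches in one tree; CategoricalEncoder one-hot encodes the leaf
    # indices actually observed per tree, which is why it can stand in
    # for OneHotEncoder here.
    leaves = rf.apply(X)
    enc = CategoricalEncoder().fit(leaves)
    lm = LogisticRegression().fit(enc.transform(leaves), y)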
7 changes: 4 additions & 3 deletions sklearn/feature_extraction/dict_vectorizer.py
@@ -39,7 +39,8 @@ class DictVectorizer(BaseEstimator, TransformerMixin):
However, note that this transformer will only do a binary one-hot encoding
when feature values are of type string. If categorical features are
represented as numeric values such as int, the DictVectorizer can be
followed by OneHotEncoder to complete binary one-hot encoding.
followed by :class:`sklearn.preprocessing.CategoricalEncoder` to complete
binary one-hot encoding.
Features that do not occur in a sample (mapping) will have a zero value
in the resulting array/matrix.
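
The workflow the updated sentence describes, as a short sketch (editorial
illustration; the dicts below are invented)::

    from sklearn.feature_extraction import DictVectorizer
    from sklearn.preprocessing import CategoricalEncoder

    # An integer-coded categorical ('browser') passes through DictVectorizer
    # as a plain numeric column ...
    measurements = [{'browser': 1, 'temperature': 33.},
                    {'browser': 3, 'temperature': 12.}]
    vec = DictVectorizer(sparse=False)
    X = vec.fit_transform(measurements)

    # ... so CategoricalEncoder is applied afterwards to finish the
    # one-hot encoding of that column.
    browser_column = X[:, [vec.vocabulary_['browser']]]
    onehot = CategoricalEncoder().fit_transform(browser_column)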
@@ -88,8 +89,8 @@ class DictVectorizer(BaseEstimator, TransformerMixin):
See also
--------
FeatureHasher : performs vectorization using only a hash function.
sklearn.preprocessing.OneHotEncoder : handles nominal/categorical features
encoded as columns of integers.
sklearn.preprocessing.CategoricalEncoder : handles nominal/categorical
features encoded as columns of arbitrary data types.
"""

def __init__(self, dtype=np.float64, separator="=", sparse=True,
2 changes: 2 additions & 0 deletions sklearn/preprocessing/__init__.py
@@ -22,6 +22,7 @@
from .data import minmax_scale
from .data import quantile_transform
from .data import OneHotEncoder
from .data import CategoricalEncoder

from .data import PolynomialFeatures

@@ -46,6 +47,7 @@
'QuantileTransformer',
'Normalizer',
'OneHotEncoder',
'CategoricalEncoder',
'RobustScaler',
'StandardScaler',
'add_dummy_feature',
