[MRG+2] LOF algorithm (Anomaly Detection) (scikit-learn#5279)
* LOF algorithm

add tests and example

fix DeprecationWarning by reshaping one-sample data with reshape(1, -1)

LOF with inheritance

lof and lof2 return same score

fix bugs

fix bugs

optimized and cosmit

rm lof2

cosmit

rm MixinLOF + fit_predict

fix travis - optimize pairwise_distance like in KNeighborsMixin.kneighbors

add comparison example + doc

LOF -> LocalOutlierFactor
cosmit

change LOF API:
-fit(X).predict() and fit(X).decision_function() do prediction on X without
 considering samples as their own neighbors (i.e. without considering X as a
 new dataset, as fit(X).predict(X) would)
-rm fit_predict() method
-add a contamination parameter so that predict returns a binary value, like
 other anomaly detection algorithms

cosmit

doc + debug example

correction doc

pass on doc + examples

pep8 + fix warnings

first attempt at fixing API issues

minor changes

takes into account tguillemot advice

-remove pairwise_distance calculation, as it is too heavy in memory
-add benchmarks

cosmit

minor changes + deals with duplicates

fix deprecation warnings

* factorize the two for loops

* take into account @albertthomas88 review and cosmit

* fix doc

* alex review + rebase

* make predict private add outlier_factor_ attribute and update tests

* make fit_predict take y argument

* fix benchmarks file

* update examples

* make decision_function public (rm X=None default)

* fix travis

* take into account tguillemot review + remove useless k_distance function

* fix broken links :meth:`kneighbors`

* cosmit

* whatsnew

* amueller review + remove _local_outlier_factor method

* add n_neighbors_ attribute: the effective number of neighbors used

* make decision_function private and add the negative_outlier_factor_ attribute
ngoix authored and amueller committed Oct 25, 2016
1 parent 73d3f03 commit 788a458
Showing 12 changed files with 710 additions and 33 deletions.
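
Before the file-by-file diff, a minimal sketch of the API this commit settles
on, pieced together from the commit message and the updated example below; the
toy data and parameter values here are illustrative assumptions, not part of
the commit:

    import numpy as np
    from sklearn.neighbors import LocalOutlierFactor

    rng = np.random.RandomState(42)
    # toy data: a tight inlier blob plus a few scattered outliers
    X = np.r_[0.3 * rng.randn(100, 2),
              rng.uniform(low=-6, high=6, size=(10, 2))]

    clf = LocalOutlierFactor(n_neighbors=20, contamination=0.1)

    # fit_predict scores the training samples without counting each sample
    # as its own neighbor, and thresholds them via the contamination
    # parameter: +1 for inliers, -1 for outliers
    y_pred = clf.fit_predict(X)

    # the raw scores are exposed as an attribute: the lower, the more abnormal
    scores = clf.negative_outlier_factor_

Note that decision_function is private in this version, so raw scores are read
from the attribute rather than computed through a public method.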
119 changes: 119 additions & 0 deletions benchmarks/bench_lof.py
@@ -0,0 +1,119 @@
"""
============================
LocalOutlierFactor benchmark
============================
A test of LocalOutlierFactor on classical anomaly detection datasets.
"""

from time import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import LocalOutlierFactor
from sklearn.metrics import roc_curve, auc
from sklearn.datasets import fetch_kddcup99, fetch_covtype, fetch_mldata
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle as sh

print(__doc__)

np.random.seed(2)

# datasets available: ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
datasets = ['shuttle']

novelty_detection = True # if False, training set polluted by outliers

for dataset_name in datasets:
    # loading and vectorization
    print('loading data')
    if dataset_name in ['http', 'smtp', 'SA', 'SF']:
        dataset = fetch_kddcup99(subset=dataset_name, shuffle=True,
                                 percent10=False)
        X = dataset.data
        y = dataset.target

    if dataset_name == 'shuttle':
        dataset = fetch_mldata('shuttle')
        X = dataset.data
        y = dataset.target
        X, y = sh(X, y)
        # we remove data with label 4
        # normal data are then those of class 1
        s = (y != 4)
        X = X[s, :]
        y = y[s]
        y = (y != 1).astype(int)

    if dataset_name == 'forestcover':
        dataset = fetch_covtype(shuffle=True)
        X = dataset.data
        y = dataset.target
        # normal data are those with attribute 2
        # abnormal those with attribute 4
        s = (y == 2) + (y == 4)
        X = X[s, :]
        y = y[s]
        y = (y != 2).astype(int)

    print('vectorizing data')

    if dataset_name == 'SF':
        # one-hot encode the categorical second column
        lb = LabelBinarizer()
        lb.fit(X[:, 1])
        x1 = lb.transform(X[:, 1])
        X = np.c_[X[:, :1], x1, X[:, 2:]]
        y = (y != 'normal.').astype(int)

    if dataset_name == 'SA':
        # one-hot encode the three categorical columns
        lb = LabelBinarizer()
        lb.fit(X[:, 1])
        x1 = lb.transform(X[:, 1])
        lb.fit(X[:, 2])
        x2 = lb.transform(X[:, 2])
        lb.fit(X[:, 3])
        x3 = lb.transform(X[:, 3])
        X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
        y = (y != 'normal.').astype(int)

    if dataset_name == 'http' or dataset_name == 'smtp':
        y = (y != 'normal.').astype(int)

    n_samples, n_features = np.shape(X)
    n_samples_train = n_samples // 2
    n_samples_test = n_samples - n_samples_train

    X = X.astype(float)
    X_train = X[:n_samples_train, :]
    X_test = X[n_samples_train:, :]
    y_train = y[:n_samples_train]
    y_test = y[n_samples_train:]

    if novelty_detection:
        # keep only normal samples for training
        X_train = X_train[y_train == 0]
        y_train = y_train[y_train == 0]

    print('LocalOutlierFactor processing...')
    model = LocalOutlierFactor(n_neighbors=20)
    tstart = time()
    model.fit(X_train)
    fit_time = time() - tstart
    tstart = time()

    scoring = -model.decision_function(X_test)  # the lower, the more normal
    predict_time = time() - tstart
    fpr, tpr, thresholds = roc_curve(y_test, scoring)
    AUC = auc(fpr, tpr)
    plt.plot(fpr, tpr, lw=1,
             label=('ROC for %s (area = %0.3f, train-time: %0.2fs, '
                    'test-time: %0.2fs)' % (dataset_name, AUC, fit_time,
                                            predict_time)))

plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
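
One subtlety in the benchmark above is the sign flip when building scoring:
roc_curve expects scores that increase with the positive class (here y == 1
for anomalies), while decision_function-style scores are larger for more
normal points. A self-contained check of this convention, with made-up
numbers:

    import numpy as np
    from sklearn.metrics import roc_curve, auc

    y_true = np.array([0, 0, 0, 1, 1])                # 1 marks an anomaly
    decision = np.array([2.0, 1.5, 1.0, -2.0, -1.0])  # higher = more normal
    # negate so that higher means more abnormal, matching y_true == 1
    fpr, tpr, thresholds = roc_curve(y_true, -decision)
    print(auc(fpr, tpr))  # 1.0: every outlier is ranked above every inlier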
3 changes: 2 additions & 1 deletion doc/modules/classes.rst
@@ -1050,7 +1050,8 @@ See the :ref:`metrics` section of the user guide for further details.
    neighbors.LSHForest
    neighbors.DistanceMetric
    neighbors.KernelDensity
-
+   neighbors.LocalOutlierFactor
+
 .. autosummary::
    :toctree: generated/
    :template: function.rst
89 changes: 75 additions & 14 deletions doc/modules/outlier_detection.rst
@@ -165,18 +165,76 @@ This strategy is illustrated below.

 * See :ref:`sphx_glr_auto_examples_covariance_plot_outlier_detection.py` for a
   comparison of :class:`ensemble.IsolationForest` with
+  :class:`neighbors.LocalOutlierFactor`,
   :class:`svm.OneClassSVM` (tuned to perform like an outlier detection
   method) and a covariance-based outlier detection with
-  :class:`covariance.MinCovDet`.
+  :class:`covariance.EllipticEnvelope`.
 
 .. topic:: References:
 
    .. [LTZ2008] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest."
       Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on.
 
-One-class SVM versus Elliptic Envelope versus Isolation Forest
---------------------------------------------------------------
+Local Outlier Factor
+--------------------
+Another efficient way to perform outlier detection on moderately
+high-dimensional datasets is to use the Local Outlier Factor (LOF) algorithm.
+
+The :class:`neighbors.LocalOutlierFactor` (LOF) algorithm computes a score
+(called the local outlier factor) reflecting the degree of abnormality of the
+observations.
+It measures the local density deviation of a given data point with respect to
+its neighbors. The idea is to detect the samples that have a substantially
+lower density than their neighbors.
+
+In practice the local density is obtained from the k-nearest neighbors.
+The LOF score of an observation is equal to the ratio of the average local
+density of its k-nearest neighbors to its own local density: a normal instance
+is expected to have a local density similar to that of its neighbors, while
+abnormal data are expected to have a much smaller local density.
+
+The number k of neighbors considered (alias parameter n_neighbors) is typically
+chosen 1) greater than the minimum number of objects a cluster has to contain,
+so that other objects can be local outliers relative to this cluster, and 2)
+smaller than the maximum number of close-by objects that can potentially be
+local outliers.
+In practice, such information is generally not available, and taking
+n_neighbors=20 appears to work well in general.
+When the proportion of outliers is high (i.e. greater than 10%, as in the
+example below), n_neighbors should be greater (n_neighbors=35 in the example
+below).
+
+The strength of the LOF algorithm is that it takes both local and global
+properties of datasets into consideration: it can perform well even on
+datasets where abnormal samples have different underlying densities.
+The question is not how isolated the sample is, but how isolated it is with
+respect to the surrounding neighborhood.
+
+This strategy is illustrated below.
+
+.. figure:: ../auto_examples/neighbors/images/sphx_glr_plot_lof_001.png
+   :target: ../auto_examples/neighbors/plot_lof.html
+   :align: center
+   :scale: 75%
+
+.. topic:: Examples:
+
+   * See :ref:`sphx_glr_auto_examples_neighbors_plot_lof.py` for
+     an illustration of the use of :class:`neighbors.LocalOutlierFactor`.
+
+   * See :ref:`sphx_glr_auto_examples_covariance_plot_outlier_detection.py`
+     for a comparison with other anomaly detection methods.
+
+.. topic:: References:
+
+   .. [BKNS2000] Breunig, Kriegel, Ng, and Sander (2000)
+      `LOF: identifying density-based local outliers.
+      <http://www.dbs.ifi.lmu.de/Publikationen/Papers/LOF.pdf>`_
+      Proc. ACM SIGMOD
+
+One-class SVM versus Elliptic Envelope versus Isolation Forest versus LOF
+-------------------------------------------------------------------------
 
 Strictly-speaking, the One-class SVM is not an outlier-detection method,
 but a novelty-detection method: its training set should not be
@@ -188,7 +246,8 @@ results in these situations.
 The examples below illustrate how the performance of the
 :class:`covariance.EllipticEnvelope` degrades as the data is less and
 less unimodal. The :class:`svm.OneClassSVM` works better on data with
-multiple modes and :class:`ensemble.IsolationForest` performs well in every cases.
+multiple modes, and :class:`ensemble.IsolationForest` and
+:class:`neighbors.LocalOutlierFactor` perform well in every case.
 
 .. |outlier1| image:: ../auto_examples/covariance/images/sphx_glr_plot_outlier_detection_001.png
    :target: ../auto_examples/covariance/plot_outlier_detection.html
Expand All @@ -202,7 +261,7 @@ multiple modes and :class:`ensemble.IsolationForest` performs well in every case
    :target: ../auto_examples/covariance/plot_outlier_detection.html
    :scale: 50%
 
-.. list-table:: **Comparing One-class SVM approach, and elliptic envelope**
+.. list-table:: **Comparing One-class SVM, Isolation Forest, LOF, and Elliptic Envelope**
    :widths: 40 60
 
 *
@@ -213,31 +272,33 @@
       opposite, the decision rule based on fitting an
       :class:`covariance.EllipticEnvelope` learns an ellipse, which
       fits well the inlier distribution. The :class:`ensemble.IsolationForest`
-      performs as well.
-    - |outlier1|
+      and :class:`neighbors.LocalOutlierFactor` perform as well.
+    - |outlier1|
 
 *
     - As the inlier distribution becomes bimodal, the
       :class:`covariance.EllipticEnvelope` does not fit well the
-      inliers. However, we can see that both :class:`ensemble.IsolationForest`
-      and :class:`svm.OneClassSVM` have difficulties to detect the two modes,
+      inliers. However, we can see that :class:`ensemble.IsolationForest`,
+      :class:`svm.OneClassSVM` and :class:`neighbors.LocalOutlierFactor`
+      have difficulties detecting the two modes,
       and that the :class:`svm.OneClassSVM`
-      tends to overfit: because it has not model of inliers, it
+      tends to overfit: because it has no model of inliers, it
       interprets a region where, by chance some outliers are
       clustered, as inliers.
     - |outlier2|
 
 *
     - If the inlier distribution is strongly non Gaussian, the
       :class:`svm.OneClassSVM` is able to recover a reasonable
-      approximation as well as :class:`ensemble.IsolationForest`,
+      approximation, as do :class:`ensemble.IsolationForest`
+      and :class:`neighbors.LocalOutlierFactor`,
       whereas the :class:`covariance.EllipticEnvelope` completely fails.
     - |outlier3|
 
 .. topic:: Examples:
 
    * See :ref:`sphx_glr_auto_examples_covariance_plot_outlier_detection.py` for a
      comparison of the :class:`svm.OneClassSVM` (tuned to perform like
-     an outlier detection method), the :class:`ensemble.IsolationForest`
-     and a covariance-based outlier
-     detection with :class:`covariance.MinCovDet`.
+     an outlier detection method), the :class:`ensemble.IsolationForest`,
+     the :class:`neighbors.LocalOutlierFactor`
+     and a covariance-based outlier detection with :class:`covariance.EllipticEnvelope`.
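
For reference, the score that the new documentation section above describes
informally can be written out following Breunig et al. (2000), cited there as
[BKNS2000]; the notation follows the paper rather than the code. With N_k(p)
the set of k-nearest neighbors of a point p, and k-distance(o) the distance
from o to its k-th nearest neighbor:

.. math::

   \mathrm{reach\text{-}dist}_k(p, o) = \max\{\, k\text{-}\mathrm{distance}(o),\; d(p, o) \,\}

.. math::

   \mathrm{lrd}_k(p) = \left( \frac{1}{|N_k(p)|} \sum_{o \in N_k(p)}
   \mathrm{reach\text{-}dist}_k(p, o) \right)^{-1}

.. math::

   \mathrm{LOF}_k(p) = \frac{1}{|N_k(p)|} \sum_{o \in N_k(p)}
   \frac{\mathrm{lrd}_k(o)}{\mathrm{lrd}_k(p)}

A score near 1 means the point is about as dense as its neighborhood, while
scores substantially above 1 flag outliers. The estimator stores these scores
negated in the negative_outlier_factor_ attribute, so that lower means more
abnormal.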
5 changes: 4 additions & 1 deletion doc/whats_new.rst
@@ -16,6 +16,9 @@ Changelog
 New features
 ............
 
+   - Added the :class:`neighbors.LocalOutlierFactor` class for anomaly detection based
+     on nearest neighbors. By `Nicolas Goix`_ and `Alexandre Gramfort`_.
+
 Enhancements
 ............

@@ -4740,7 +4743,7 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.

 .. _Eric Martin: http://www.ericmart.in
 
-.. _Nicolas Goix: https://webperso.telecom-paristech.fr/front/frontoffice.php?SP_ID=241
+.. _Nicolas Goix: https://perso.telecom-paristech.fr/~goix/
 
 .. _Cory Lorenz: https://github.com/clorenz7

41 changes: 29 additions & 12 deletions examples/covariance/plot_outlier_detection.py
@@ -18,6 +18,9 @@
   hence more adapted to large-dimensional settings, even if it performs
   quite well in the examples below.
+- using the Local Outlier Factor to measure the local deviation of a given
+  data point with respect to its neighbors by comparing their local density.
+
 The ground truth about inliers and outliers is given by the points colors
 while the orange-filled area indicates which points are reported as inliers
 by each method.
@@ -27,7 +30,6 @@
 threshold on the decision_function to separate out the corresponding
 fraction.
 """
-print(__doc__)
 
 import numpy as np
 from scipy import stats
@@ -37,6 +39,9 @@
 from sklearn import svm
 from sklearn.covariance import EllipticEnvelope
 from sklearn.ensemble import IsolationForest
+from sklearn.neighbors import LocalOutlierFactor
+
+print(__doc__)
 
 rng = np.random.RandomState(42)

@@ -52,10 +57,13 @@
"Robust covariance": EllipticEnvelope(contamination=outliers_fraction),
"Isolation Forest": IsolationForest(max_samples=n_samples,
contamination=outliers_fraction,
random_state=rng)}
random_state=rng),
"Local Outlier Factor": LocalOutlierFactor(
n_neighbors=35,
contamination=outliers_fraction)}

# Compare given classifiers under given settings
xx, yy = np.meshgrid(np.linspace(-7, 7, 500), np.linspace(-7, 7, 500))
xx, yy = np.meshgrid(np.linspace(-7, 7, 100), np.linspace(-7, 7, 100))
n_inliers = int((1. - outliers_fraction) * n_samples)
n_outliers = int(outliers_fraction * n_samples)
ground_truth = np.ones(n_samples, dtype=int)
@@ -72,19 +80,27 @@
     X = np.r_[X, np.random.uniform(low=-6, high=6, size=(n_outliers, 2))]
 
     # Fit the model
-    plt.figure(figsize=(10.8, 3.6))
+    plt.figure(figsize=(9, 7))
     for i, (clf_name, clf) in enumerate(classifiers.items()):
         # fit the data and tag outliers
-        clf.fit(X)
-        scores_pred = clf.decision_function(X)
+        if clf_name == "Local Outlier Factor":
+            y_pred = clf.fit_predict(X)
+            scores_pred = clf.negative_outlier_factor_
+        else:
+            clf.fit(X)
+            scores_pred = clf.decision_function(X)
+            y_pred = clf.predict(X)
         threshold = stats.scoreatpercentile(scores_pred,
                                             100 * outliers_fraction)
-        y_pred = clf.predict(X)
         n_errors = (y_pred != ground_truth).sum()
         # plot the levels lines and the points
-        Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
+        if clf_name == "Local Outlier Factor":
+            # decision_function is private for LOF
+            Z = clf._decision_function(np.c_[xx.ravel(), yy.ravel()])
+        else:
+            Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
         Z = Z.reshape(xx.shape)
-        subplot = plt.subplot(1, 3, i + 1)
+        subplot = plt.subplot(2, 2, i + 1)
         subplot.contourf(xx, yy, Z, levels=np.linspace(Z.min(), threshold, 7),
                          cmap=plt.cm.Blues_r)
         a = subplot.contour(xx, yy, Z, levels=[threshold],
@@ -97,11 +113,12 @@
         subplot.legend(
             [a.collections[0], b, c],
             ['learned decision function', 'true inliers', 'true outliers'],
-            prop=matplotlib.font_manager.FontProperties(size=11),
+            prop=matplotlib.font_manager.FontProperties(size=10),
             loc='lower right')
-        subplot.set_title("%d. %s (errors: %d)" % (i + 1, clf_name, n_errors))
+        subplot.set_xlabel("%d. %s (errors: %d)" % (i + 1, clf_name, n_errors))
         subplot.set_xlim((-7, 7))
         subplot.set_ylim((-7, 7))
-    plt.subplots_adjust(0.04, 0.1, 0.96, 0.92, 0.1, 0.26)
+    plt.subplots_adjust(0.04, 0.1, 0.96, 0.94, 0.1, 0.26)
     plt.suptitle("Outlier detection")
 
 plt.show()
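
The contamination logic that this example special-cases for LOF is worth
isolating: a cutoff is placed at the contamination quantile of the scores, and
everything at or below it is flagged. A minimal sketch with made-up scores
(variable names are illustrative, not from the example):

    import numpy as np
    from scipy import stats

    # hypothetical scores in the style of negative_outlier_factor_:
    # the lower, the more abnormal
    scores = np.array([-0.9, -1.1, -1.0, -3.5, -0.95])
    outliers_fraction = 0.2

    # cutoff at the contamination quantile, as in the example above
    threshold = stats.scoreatpercentile(scores, 100 * outliers_fraction)
    y_pred = np.where(scores > threshold, 1, -1)
    print(threshold)  # -1.58
    print(y_pred)     # [ 1  1  1 -1  1], i.e. one point in five flagged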
