completed sparsity sections; added benchmark
oddskool committed Nov 4, 2013
1 parent 501230d commit 1bdb7b2
Showing 3 changed files with 151 additions and 77 deletions.
103 changes: 103 additions & 0 deletions benchmarks/bench_sparsify.py
@@ -0,0 +1,103 @@
"""
Benchmark SGD prediction time with dense/sparse coefficients.
Invoke with
-----------
$ kernprof.py -l bench_sparsify.py
$ python -m line_profiler bench_sparsify.py.lprof
Typical output
--------------
input data sparsity: 0.050000
true coef sparsity: 0.000100
test data sparsity: 0.027400
model sparsity: 0.000024
r^2 on test data (dense model) : 0.233651
r^2 on test data (sparse model) : 0.233651
Wrote profile results to bench_sparsify.py.lprof
Timer unit: 1e-06 s
File: bench_sparsify.py
Function: benchmark_dense_predict at line 51
Total time: 0.532979 s
Line # Hits Time Per Hit % Time Line Contents
==============================================================
51 @profile
52 def benchmark_dense_predict():
53 301 640 2.1 0.1 for _ in range(300):
54 300 532339 1774.5 99.9 clf.predict(X_test)
File: bench_sparsify.py
Function: benchmark_sparse_predict at line 56
Total time: 0.39274 s
Line # Hits Time Per Hit % Time Line Contents
==============================================================
56 @profile
57 def benchmark_sparse_predict():
58 1 10854 10854.0 2.8 X_test_sparse = csr_matrix(X_test)
59 301 477 1.6 0.1 for _ in range(300):
60 300 381409 1271.4 97.1 clf.predict(X_test_sparse)
"""

from scipy.sparse import csr_matrix
import numpy as np
from sklearn.linear_model import SGDRegressor
from sklearn.metrics import r2_score

np.random.seed(42)


def sparsity_ratio(X):
    # fraction of non-zero entries; X.size handles both 1-D and 2-D arrays
    return np.count_nonzero(X) / float(X.size)

n_samples, n_features = 5000, 300
X = np.random.randn(n_samples, n_features)
inds = np.arange(n_samples)
np.random.shuffle(inds)
X[inds[int(n_features / 1.2):]] = 0  # sparsify input
print("input data sparsity: %f" % sparsity_ratio(X))
coef = 3 * np.random.randn(n_features)
inds = np.arange(n_features)
np.random.shuffle(inds)
coef[inds[n_features // 2:]] = 0  # sparsify coef
print("true coef sparsity: %f" % sparsity_ratio(coef))
y = np.dot(X, coef)

# add noise
y += 0.01 * np.random.normal(size=(n_samples,))

# Split data in train set and test set
n_samples = X.shape[0]
X_train, y_train = X[:n_samples // 2], y[:n_samples // 2]
X_test, y_test = X[n_samples // 2:], y[n_samples // 2:]
print("test data sparsity: %f" % sparsity_ratio(X_test))

###############################################################################
clf = SGDRegressor(penalty='l1', alpha=.2, fit_intercept=True, n_iter=2000)
clf.fit(X_train, y_train)
print("model sparsity: %f" % sparsity_ratio(clf.coef_))

# the @profile decorator is injected by kernprof when running with `kernprof.py -l`
@profile
def benchmark_dense_predict():
for _ in range(300):
clf.predict(X_test)

@profile
def benchmark_sparse_predict():
X_test_sparse = csr_matrix(X_test)
for _ in range(300):
clf.predict(X_test_sparse)

def score(y_test, y_pred, case):
r2 = r2_score(y_test, y_pred)
print("r^2 on test data (%s) : %f" % (case, r2))

score(y_test, clf.predict(X_test), 'dense model')
benchmark_dense_predict()
clf.sparsify()
score(y_test, clf.predict(X_test), 'sparse model')
benchmark_sparse_predict()
72 changes: 48 additions & 24 deletions doc/modules/performance.rst
@@ -74,7 +74,35 @@ memory footprint and estimator).
Influence of the Input Data Representation
------------------------------------------

Numpy / Scipy support sparse matrix formats which are optimized for storing
sparse data. The main feature of sparse formats is that zeros are not stored,
so if your data is sparse you use much less memory. A non-zero value in a
sparse (`CSR or CSC <https://docs.scipy.org/doc/scipy/reference/sparse.html>`_)
representation only takes, on average, one 32-bit integer position plus the
64-bit floating point value. Using sparse input with a dense (or sparse)
linear model can speed up prediction by quite a bit, as only the non-zero
valued features impact the dot product and thus the model predictions. Hence
if you have 100 non-zeros in a 1e6-dimensional space, you only need 100
multiply-and-add operations instead of 1e6.
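
For instance, here is a rough illustration of the memory footprint of a dense
array versus its CSR representation (a sketch only; the exact byte counts
depend on the index dtypes Scipy chooses):

>>> import numpy as np
>>> from scipy.sparse import csr_matrix
>>> X = np.zeros((1000, 1000))
>>> X[0, :100] = 1.0  # only 100 non-zero entries out of 1e6
>>> X_csr = csr_matrix(X)
>>> dense_bytes = X.nbytes
>>> csr_bytes = X_csr.data.nbytes + X_csr.indices.nbytes + X_csr.indptr.nbytes
>>> print("dense: %d bytes, CSR: %d bytes" % (dense_bytes, csr_bytes))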

Note that dense / dense operations benefit from BLAS-provided SSE vectorized
operations, multithreading and lower CPU cache miss rates, while the sparse
dot product cannot leverage an optimized BLAS implementation. The sparsity
should therefore typically be quite high (10% non-zeros max, to be checked
depending on the hardware) for the sparse input representation to be faster
than the dense input representation on a machine with many CPUs and an
optimized BLAS implementation.
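
To get a feeling for the crossover point on your own hardware, you can
micro-benchmark the dot product of the data with the model coefficients for
various sparsity levels and representations. Here is a minimal sketch using
the ``safe_sparse_dot`` helper from scikit-learn (the shapes and density
values are only illustrative):

>>> import timeit
>>> import numpy as np
>>> from scipy.sparse import csr_matrix
>>> from sklearn.utils.extmath import safe_sparse_dot
>>> n_samples, n_features = 300, 10000
>>> coef = np.random.randn(1, n_features)  # stand-in for fitted coefficients
>>> for density in (0.5, 0.1, 0.01):
...     X = np.random.randn(n_samples, n_features)
...     X[np.random.rand(n_samples, n_features) > density] = 0.0  # sparsify
...     X_csr = csr_matrix(X)
...     t_dense = timeit.timeit(lambda: safe_sparse_dot(X, coef.T), number=10)
...     t_csr = timeit.timeit(lambda: safe_sparse_dot(X_csr, coef.T), number=10)
...     print("density=%.2f dense=%.4fs csr=%.4fs" % (density, t_dense, t_csr))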

Here is sample code to test the sparsity of your input:

>>> import numpy as np
>>> def sparsity_ratio(X):
...     return np.count_nonzero(X) / float(X.shape[0] * X.shape[1])
>>> print("input sparsity ratio:", sparsity_ratio(X))

Now if you want to leverage sparsity for your input data, you should either
build your input matrix directly in the CSR or CSC format, or convert a dense
array with Scipy's ``csr_matrix()`` helper (other sparse formats also provide
a ``tocsr()`` method).
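
For example, assuming ``clf`` is an already fitted linear model and ``X_test``
a dense numpy array (both names are placeholders), the conversion and
prediction look like this:

>>> from scipy.sparse import csr_matrix
>>> X_test_sparse = csr_matrix(X_test)   # convert the dense test data to CSR
>>> y_pred = clf.predict(X_test_sparse)  # predict directly on the sparse input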

Prediction Throughput
=====================
@@ -172,28 +200,24 @@ Model Compression

Model compression in scikit-learn only concerns linear models for the moment.
In this context it means that we want to control the model sparsity (i.e. the
number of non-zero coordinates in the model vectors). It is generally a good
idea to combine model sparsity with sparse input data representation.

Here is sample code that illustrates the use of the ``sparsify()`` method:

>>> clf = SGDRegressor(penalty='l1')
>>> clf.fit(X_train, y_train)
>>> clf.sparsify()
>>> clf.predict(X_test)

A typical benchmark (:ref:`benchmarks_bench_sparsify.py`) on synthetic data
yields a roughly 30% decrease in prediction latency when both the model and
the input are sparse (non-zero ratios of 0.000024 for the model coefficients
and 0.027400 for the input, respectively). Your mileage may vary depending on
the sparsity and size of your data and model.

Links
-----

- `scikit-learn developer performance documentation <https://scikit-learn.org/stable/developers/performance.html>`_
- `Scipy sparse matrix formats documentation <https://docs.scipy.org/doc/scipy/reference/sparse.html>`_
53 changes: 0 additions & 53 deletions doc/modules/sparsity_benchmark.py

This file was deleted.
