Skip to content

Commit

Permalink
Add logsumexp and comb to utils.fixes (scikit-learn#9046)
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyak authored and agramfort committed Jun 10, 2017
1 parent 93871e2 commit 95aa295
Show file tree
Hide file tree
Showing 15 changed files with 25 additions and 21 deletions.
2 changes: 1 addition & 1 deletion sklearn/decomposition/online_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@

import numpy as np
import scipy.sparse as sp
from scipy.misc import logsumexp
from scipy.special import gammaln
import warnings

from ..base import BaseEstimator, TransformerMixin
from ..utils import (check_random_state, check_array,
gen_batches, gen_even_slices, _get_n_jobs)
from ..utils.fixes import logsumexp
from ..utils.validation import check_non_negative
from ..externals.joblib import Parallel, delayed
from ..externals.six.moves import xrange
Expand Down
3 changes: 1 addition & 2 deletions sklearn/ensemble/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from abc import abstractmethod

from .base import BaseEnsemble
from ..base import BaseEstimator
from ..base import ClassifierMixin
from ..base import RegressorMixin
from ..externals import six
Expand All @@ -40,7 +39,6 @@
import numpy as np

from scipy import stats
from scipy.misc import logsumexp
from scipy.sparse import csc_matrix
from scipy.sparse import csr_matrix
from scipy.sparse import issparse
Expand All @@ -57,6 +55,7 @@
from ..utils import column_or_1d
from ..utils import check_consistent_length
from ..utils import deprecated
from ..utils.fixes import logsumexp
from ..utils.stats import _weighted_percentile
from ..utils.validation import check_is_fitted
from ..utils.multiclass import check_classification_targets
Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from itertools import product

import numpy as np
from scipy.misc import comb
from scipy.sparse import csr_matrix
from scipy.sparse import csc_matrix
from scipy.sparse import coo_matrix
Expand Down Expand Up @@ -42,6 +41,7 @@
from sklearn.model_selection import GridSearchCV
from sklearn.svm import LinearSVC
from sklearn.utils.validation import check_random_state
from sklearn.utils.fixes import comb

from sklearn.tree.tree import SPARSE_SPLITTERS

Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import numpy as np
from scipy import optimize, sparse
from scipy.misc import logsumexp
from scipy.special import expit

from .base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
Expand All @@ -27,6 +26,7 @@
from ..utils.extmath import (log_logistic, safe_sparse_dot, softmax,
squared_norm)
from ..utils.extmath import row_norms
from ..utils.fixes import logsumexp
from ..utils.optimize import newton_cg
from ..utils.validation import check_X_y
from ..exceptions import NotFittedError
Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/tests/test_sag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import math
import numpy as np
import scipy.sparse as sp
from scipy.misc import logsumexp

from sklearn.linear_model.sag import get_auto_step_size
from sklearn.linear_model.sag_fast import _multinomial_grad_loss_all_samples
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.linear_model.base import make_dataset
from sklearn.linear_model.logistic import _multinomial_loss_grad

from sklearn.utils.fixes import logsumexp
from sklearn.utils.extmath import row_norms
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_almost_equal
Expand Down
2 changes: 1 addition & 1 deletion sklearn/metrics/cluster/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from math import log

import numpy as np
from scipy.misc import comb
from scipy import sparse as sp

from .expected_mutual_info_fast import expected_mutual_information
from ...utils.validation import check_array
from ...utils.fixes import comb


def comb2(n):
Expand Down
2 changes: 1 addition & 1 deletion sklearn/mixture/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from time import time

import numpy as np
from scipy.misc import logsumexp

from .. import cluster
from ..base import BaseEstimator
from ..base import DensityMixin
from ..externals import six
from ..exceptions import ConvergenceWarning
from ..utils import check_array, check_random_state
from ..utils.fixes import logsumexp


def _check_shape(param, param_shape, name):
Expand Down
2 changes: 1 addition & 1 deletion sklearn/mixture/dpgmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from scipy.special import digamma as _digamma, gammaln as _gammaln
from scipy import linalg
from scipy.linalg import pinvh
from scipy.misc import logsumexp
from scipy.spatial.distance import cdist

from ..externals.six.moves import xrange
from ..utils import check_random_state, check_array, deprecated
from ..utils.fixes import logsumexp
from ..utils.extmath import squared_norm, stable_cumsum
from ..utils.validation import check_is_fitted
from .. import cluster
Expand Down
2 changes: 1 addition & 1 deletion sklearn/mixture/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

import numpy as np
from scipy import linalg
from scipy.misc import logsumexp

from ..base import BaseEstimator
from ..utils import check_random_state, check_array, deprecated
from ..utils.fixes import logsumexp
from ..utils.validation import check_is_fitted
from .. import cluster

Expand Down
3 changes: 1 addition & 2 deletions sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@

import numpy as np

from scipy.misc import comb
from ..utils import indexable, check_random_state, safe_indexing
from ..utils.validation import _num_samples, column_or_1d
from ..utils.validation import check_array
from ..utils.multiclass import type_of_target
from ..externals.six import with_metaclass
from ..externals.six.moves import zip
from ..utils.fixes import signature
from ..utils.fixes import signature, comb
from ..base import _pprint

__all__ = ['BaseCrossValidator',
Expand Down
3 changes: 2 additions & 1 deletion sklearn/model_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
from scipy import stats
from scipy.misc import comb
from itertools import combinations
from itertools import combinations_with_replacement

Expand Down Expand Up @@ -57,6 +56,8 @@
from sklearn.externals import six
from sklearn.externals.six.moves import zip

from sklearn.utils.fixes import comb

from sklearn.svm import SVC

X = np.ones(10)
Expand Down
2 changes: 1 addition & 1 deletion sklearn/naive_bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from abc import ABCMeta, abstractmethod

import numpy as np
from scipy.misc import logsumexp
from scipy.sparse import issparse

from .base import BaseEstimator, ClassifierMixin
Expand All @@ -28,6 +27,7 @@
from .preprocessing import label_binarize
from .utils import check_X_y, check_array, check_consistent_length
from .utils.extmath import safe_sparse_dot
from .utils.fixes import logsumexp
from .utils.multiclass import _check_partial_fit_first_call
from .utils.validation import check_is_fitted
from .externals import six
Expand Down
2 changes: 1 addition & 1 deletion sklearn/utils/extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import numpy as np
from scipy import linalg
from scipy.sparse import issparse, csr_matrix
from scipy.misc import logsumexp as scipy_logsumexp

from . import check_random_state, deprecated
from .fixes import np_version
from .fixes import logsumexp as scipy_logsumexp
from ._logistic_sigmoid import _log_logistic_sigmoid
from ..externals.six.moves import xrange
from .sparsefuncs_fast import csr_row_norms
Expand Down
13 changes: 9 additions & 4 deletions sklearn/utils/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# License: BSD 3 clause

import warnings
import sys
import functools
import os
import errno

Expand All @@ -36,6 +34,7 @@ def _parse_version(version_string):
version.append(x)
return tuple(version)


euler_gamma = getattr(np, 'euler_gamma',
0.577215664901532860606512090082402431)

Expand Down Expand Up @@ -142,11 +141,17 @@ def sparse_min_max(X, axis):
# Backport fix for scikit-learn/scikit-learn#2986 / scipy/scipy#4142
from ._scipy_sparse_lsqr_backport import lsqr as sparse_lsqr
else:
from scipy.sparse.linalg import lsqr as sparse_lsqr
from scipy.sparse.linalg import lsqr as sparse_lsqr # noqa


try: # SciPy >= 0.19
from scipy.special import comb, logsumexp
except ImportError:
from scipy.misc import comb, logsumexp # noqa


def parallel_helper(obj, methodname, *args, **kwargs):
"""Helper to workaround Python 2 limitations of pickling instance methods"""
"""Workaround for Python 2 limitations of pickling instance methods"""
return getattr(obj, methodname)(*args, **kwargs)


Expand Down
4 changes: 2 additions & 2 deletions sklearn/utils/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import numpy as np
import scipy.sparse as sp
from scipy.misc import comb as combinations
from numpy.testing import assert_array_almost_equal
from sklearn.utils.random import sample_without_replacement
from sklearn.utils.random import random_choice_csc
from sklearn.utils.fixes import comb

from sklearn.utils.testing import (
assert_raises,
Expand Down Expand Up @@ -88,7 +88,7 @@ def check_sample_int_distribution(sample_without_replacement):
# Counting the number of combinations is not as good as counting the
# number of permutations. However, it works with sampling algorithms
# that do not provide a random permutation of the subset of integers.
n_expected = combinations(n_population, n_samples, exact=True)
n_expected = comb(n_population, n_samples, exact=True)

output = {}
for i in range(n_trials):
Expand Down

0 comments on commit 95aa295

Please sign in to comment.