diff --git a/sklearn/datasets/_svmlight_format_io.py b/sklearn/datasets/_svmlight_format_io.py index 2a141e1732ff7..991832c23c389 100644 --- a/sklearn/datasets/_svmlight_format_io.py +++ b/sklearn/datasets/_svmlight_format_io.py @@ -25,6 +25,7 @@ from .. import __version__ from ..utils import check_array, IS_PYPY +from ..utils._param_validation import validate_params, HasMethods if not IS_PYPY: from ._svmlight_format_fast import ( @@ -404,6 +405,17 @@ def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id): ) +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "y": ["array-like", "sparse matrix"], + "f": [str, HasMethods(["write"])], + "zero_based": ["boolean"], + "comment": [str, bytes, None], + "query_id": ["array-like", None], + "multilabel": ["boolean"], + } +) def dump_svmlight_file( X, y, @@ -428,7 +440,7 @@ def dump_svmlight_file( Training vectors, where `n_samples` is the number of samples and `n_features` is the number of features. - y : {array-like, sparse matrix}, shape = [n_samples (, n_labels)] + y : {array-like, sparse matrix}, shape = (n_samples,) or (n_samples, n_labels) Target values. Class labels must be an integer or float, or array-like objects of integer or float for multilabel classifications. @@ -442,7 +454,7 @@ def dump_svmlight_file( Whether column indices should be written zero-based (True) or one-based (False). - comment : str, default=None + comment : str or bytes, default=None Comment to insert at the top of the file. This should be either a Unicode string, which will be encoded as UTF-8, or an ASCII byte string. @@ -459,7 +471,7 @@ def dump_svmlight_file( https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html). .. versionadded:: 0.17 - parameter *multilabel* to support multilabel datasets. + parameter `multilabel` to support multilabel datasets. """ if comment is not None: # Convert comment string to list of lines in UTF-8. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index c2f0f18a7825e..b50e469f80a1e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -102,6 +102,7 @@ def _check_function_param_validation( "sklearn.cluster.ward_tree", "sklearn.covariance.empirical_covariance", "sklearn.covariance.shrunk_covariance", + "sklearn.datasets.dump_svmlight_file", "sklearn.datasets.fetch_california_housing", "sklearn.datasets.fetch_kddcup99", "sklearn.datasets.make_classification",