Skip to content

Commit

Permalink
MAINT Parameters validation for sklearn.datasets.get_data_home (sciki…
Browse files Browse the repository at this point in the history
…t-learn#26260)

Co-authored-by: Jérémie du Boisberranger <[email protected]>
  • Loading branch information
jiawei-zhang-a and jeremiedbb committed Apr 25, 2023
1 parent bc3a19d commit 363c633
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
7 changes: 6 additions & 1 deletion sklearn/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
RemoteFileMetadata = namedtuple("RemoteFileMetadata", ["filename", "url", "checksum"])


@validate_params(
{
"data_home": [str, os.PathLike, None],
}
)
def get_data_home(data_home=None) -> str:
"""Return the path of the scikit-learn data directory.
Expand All @@ -58,7 +63,7 @@ def get_data_home(data_home=None) -> str:
Returns
-------
data_home: str
data_home: str or path-like, default=None
The path to scikit-learn data directory.
"""
if data_home is None:
Expand Down
2 changes: 1 addition & 1 deletion sklearn/datasets/_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ def fetch_openml(
data_home = None
else:
data_home = get_data_home(data_home=data_home)
data_home = join(data_home, "openml")
data_home = join(str(data_home), "openml")

# check valid function arguments. data_id XOR (name, version) should be
# provided
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def _check_function_param_validation(
"sklearn.datasets.fetch_olivetti_faces",
"sklearn.datasets.fetch_rcv1",
"sklearn.datasets.fetch_species_distributions",
"sklearn.datasets.get_data_home",
"sklearn.datasets.load_breast_cancer",
"sklearn.datasets.load_diabetes",
"sklearn.datasets.load_digits",
Expand Down

0 comments on commit 363c633

Please sign in to comment.