Skip to content

Commit

Permalink
MAINT Param validation: better message when common test fails to raise (
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedbb committed Jun 28, 2023
1 parent 9ab298a commit d9212de
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
18 changes: 18 additions & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,16 @@ def _check_function_param_validation(
rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead."
)

err_msg = (
f"{func_name} does not raise an informative error message when the "
f"parameter {param_name} does not have a valid type. If any Python type "
"is valid, the constraint should be 'no_validation'."
)

# First, check that the error is raised if param doesn't match any valid type.
with pytest.raises(InvalidParameterError, match=match):
func(**{**valid_required_params, param_name: param_with_bad_type})
pytest.fail(err_msg)

# Then, for constraints that are more than a type constraint, check that the
# error is raised if param does match a valid type but does not match any valid
Expand All @@ -107,8 +114,19 @@ def _check_function_param_validation(
except NotImplementedError:
continue

err_msg = (
f"{func_name} does not raise an informative error message when the "
f"parameter {param_name} does not have a valid value.\n"
"Constraints should be disjoint. For instance "
"[StrOptions({'a_string'}), str] is not a acceptable set of "
"constraint because generating an invalid string for the first "
"constraint will always produce a valid string for the second "
"constraint."
)

with pytest.raises(InvalidParameterError, match=match):
func(**{**valid_required_params, param_name: bad_value})
pytest.fail(err_msg)


PARAM_VALIDATION_FUNCTION_LIST = [
Expand Down
16 changes: 16 additions & 0 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4287,6 +4287,12 @@ def check_param_validation(name, estimator_orig):
# the method is not accessible with the current set of parameters
continue

err_msg = (
f"{name} does not raise an informative error message when the parameter"
f" {param_name} does not have a valid type. If any Python type is"
" valid, the constraint should be 'no_validation'."
)

with raises(InvalidParameterError, match=match, err_msg=err_msg):
if any(
isinstance(X_type, str) and X_type.endswith("labels")
Expand Down Expand Up @@ -4315,6 +4321,16 @@ def check_param_validation(name, estimator_orig):
# the method is not accessible with the current set of parameters
continue

err_msg = (
f"{name} does not raise an informative error message when the "
f"parameter {param_name} does not have a valid value.\n"
"Constraints should be disjoint. For instance "
"[StrOptions({'a_string'}), str] is not a acceptable set of "
"constraint because generating an invalid string for the first "
"constraint will always produce a valid string for the second "
"constraint."
)

with raises(InvalidParameterError, match=match, err_msg=err_msg):
if any(
X_type.endswith("labels")
Expand Down

0 comments on commit d9212de

Please sign in to comment.