
FIX Sets max_samples=1 when it is a float and too low in RandomForestClassifier #25601

Merged Mar 10, 2023 (17 commits)

Conversation

@JanFidor (Contributor)
Reference Issues/PRs

Fixes #24037
Supersedes #25140

Credit to @mohitthakur13 @Da-Lan for the solution and to @sbendimerad for the contribution

@betatim (Member) left a comment


LGTM.

Thanks for picking this up and attempting to wrap it up :)

@JanFidor (Contributor, Author)

@betatim Hi, I merged main to keep the branch up to date. It's been a while since your approval, so I wanted to make sure the PR doesn't get stalled by accident and ask if there's anything else I should do to get it merged :)

@thomasjpfan (Member) left a comment


Thank you for the PR @JanFidor !

Comment on lines 149 to 150
- |Fix| :class:`ensemble.RandomForestClassifier` raise more descriptive ValueError when round(n_samples * max_samples) < 1
:pr:`25601` by :user:`Jan Fidor <JanFidor>`.

Since this is a backward incompatible change, we will need to add this to the "Changed Model section" too. For example, the following works on main but will fail after this PR:

from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier

X, y = load_wine(return_X_y=True)

clf = RandomForestClassifier(max_samples=1e-4)
clf.fit(X,y)
Suggested change
- |Fix| :class:`ensemble.RandomForestClassifier` raise more descriptive ValueError when round(n_samples * max_samples) < 1
:pr:`25601` by :user:`Jan Fidor <JanFidor>`.
- |Fix| :meth:`ensemble.RandomForestClassifier.fit` raises a more descriptive `ValueError`
when `max_samples` is a float and `round(n_samples * max_samples) < 1`.
:pr:`25601` by :user:`Jan Fidor <JanFidor>`.

-    return round(n_samples * max_samples)
+    result = round(n_samples * max_samples)
+    if result < 1:
+        raise ValueError("round(`max_samples` * `n_samples`) must be >= 1")
@thomasjpfan (Member), Feb 28, 2023


I know this does not look like the error message above, but I think the backticks do not add any information:

Suggested change
-raise ValueError("round(`max_samples` * `n_samples`) must be >= 1")
+raise ValueError("round(max_samples * n_samples) must be >= 1")

@JanFidor (Contributor, Author), Feb 28, 2023


Honestly, I added them to stay consistent with this raise, and I'm inclined to agree that the backticks are a little redundant. So if it's okay, I'd like to delete them from the error messages in this file to stay consistent.

@thomasjpfan (Member), Feb 28, 2023


So if it's okay I'd like to delete them from error messages in this file to stay consistent

Even if it's small, I prefer not to expand the scope of this PR, so it's easier to review and merge. We can clean up the file in a separate follow-up PR.

def test_raises_descriptive_bootstrap_error():
    X, y = datasets.load_wine(return_X_y=True)
    forest = RandomForestClassifier(max_samples=1e-4, class_weight="balanced_subsample")
    warn_msg = "round\\(`max_samples` \\* `n_samples`\\) must be >= 1"

Using re.escape makes the following look a little cleaner:

Suggested change
-warn_msg = "round\\(`max_samples` \\* `n_samples`\\) must be >= 1"
+warn_msg = re.escape("round(max_samples * n_samples) must be >= 1")
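As a side note, `re.escape` escapes every regex metacharacter in a literal string, which is exactly what is needed when an error message containing `(`, `)`, or `*` is passed to `pytest.raises(..., match=...)`, since `match` is treated as a regular expression. A minimal standalone illustration (not the actual test file):

```python
import re

msg = "round(max_samples * n_samples) must be >= 1"

# re.escape turns '(', ')' and '*' into literal matches,
# so the pattern matches the message text itself:
pattern = re.escape(msg)
assert re.search(pattern, msg) is not None

# The unescaped string would be a broken regex here:
# '(' opens a group and '*' is a quantifier.
```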

@@ -1807,3 +1807,11 @@ def test_read_only_buffer(monkeypatch):

    clf = RandomForestClassifier(n_jobs=2, random_state=rng)
    cross_val_score(clf, X, y, cv=2)


def test_raises_descriptive_bootstrap_error():

Suggested change
-def test_raises_descriptive_bootstrap_error():
+def test_raises_bootstrap_error_when_max_samples_too_low():
+    """Check that an error is raised when max_samples is configured too low.
+
+    Non-regression test for gh-24037.
+    """


def test_raises_descriptive_bootstrap_error():
    X, y = datasets.load_wine(return_X_y=True)
    forest = RandomForestClassifier(max_samples=1e-4, class_weight="balanced_subsample")

Can we use a pytest.mark.parametrize to check the normal case as well?

@pytest.mark.parametrize("class_weight", ["balanced_subsample", None])
def test_raises_bootstrap_error_when_max_samples_too_low(class_weight):
    ...
    forest = RandomForestClassifier(max_samples=1e-4, class_weight=class_weight)

@jeremiedbb (Member)

@thomasjpfan, @glemaitre, would it be bad to not error but just return 1: return max(round(n_samples * max_samples), 1) ?
This is a pretty common pattern in scikit-learn and we usually don't fail but return the minimum acceptable value.
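The arithmetic behind this proposal can be sketched quickly. This is a standalone illustration, not the actual scikit-learn source, and `n_samples_bootstrap` here is a hypothetical helper name:

```python
def n_samples_bootstrap(n_samples: int, max_samples: float) -> int:
    # Clamp to at least one sample instead of raising an error.
    return max(round(n_samples * max_samples), 1)

# The wine dataset used in the examples above has 178 samples:
print(n_samples_bootstrap(178, 1e-4))  # round(0.0178) == 0, clamped to 1
print(n_samples_bootstrap(178, 0.5))   # round(89.0) == 89, unchanged
```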

@thomasjpfan (Member)

Yea, I'll be okay with returning one. I would have preferred ceil(n_samples * max_samples), but for backward compatibility max(round(...), 1) is okay with me.
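The trade-off between `ceil` and `round` mentioned here can be seen with small numbers (an illustrative sketch; the values are not from the PR itself):

```python
import math

n_samples = 178  # size of the wine dataset from the earlier example

# round() can reach 0 for tiny fractions, hence the max(..., 1) clamp:
print(max(round(n_samples * 2e-3), 1))  # round(0.356) == 0 -> clamped to 1

# ceil() would never drop below 1, but changes results existing users see:
print(math.ceil(n_samples * 0.3))  # 54
print(round(n_samples * 0.3))      # 53, what round-based code yields
```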

@JanFidor (Contributor, Author)

JanFidor commented Mar 3, 2023

@jeremiedbb @thomasjpfan just wanted to get confirmation, shall I make the change to stay backward compatible and delete the "Changed Model section" entry?

@jeremiedbb (Member)

Yes, you can make the modifications in this PR.

@thomasjpfan (Member) left a comment


Thanks for the update! Could you update the title to reflect the new behavior in this PR? (The title will become the commit message.)

The docstring for RandomForestClassifier and RandomForestRegressor needs to be updated:

- If float, then draw `max_samples * X.shape[0]` samples. Thus,
`max_samples` should be in the interval `(0.0, 1.0]`.

I prefer it to be explicit and say max(round(n_samples * max_samples), 1).

sklearn/ensemble/tests/test_forest.py (thread resolved)
sklearn/ensemble/tests/test_forest.py (outdated, thread resolved)
@JanFidor JanFidor changed the title Add a more descriptive error to _get_n_samples_bootstrap in RandomForestClassifier Return 1 for _get_n_samples_bootstrap, when samples too low in RandomForestClassifier Mar 9, 2023
@jeremiedbb (Member) left a comment


LGTM. Thanks @JanFidor

@thomasjpfan thomasjpfan changed the title Return 1 for _get_n_samples_bootstrap, when samples too low in RandomForestClassifier Sets max_samples=1 when it is a float and too low in RandomForestClassifier Mar 10, 2023
@thomasjpfan thomasjpfan changed the title Sets max_samples=1 when it is a float and too low in RandomForestClassifier FIX Sets max_samples=1 when it is a float and too low in RandomForestClassifier Mar 10, 2023
@thomasjpfan (Member) left a comment


LGTM

@thomasjpfan thomasjpfan enabled auto-merge (squash) March 10, 2023 14:54
@thomasjpfan thomasjpfan merged commit 01c8e0b into scikit-learn:main Mar 10, 2023
@JanFidor (Contributor, Author)

Thanks for the reviews and help @betatim @thomasjpfan and @jeremiedbb !

Successfully merging this pull request may close these issues.

RandomForestClassifier class_weight/max_samples interaction can lead to ungraceful and nondescriptive failure