FIX Sets max_samples=1 when it is a float and too low in RandomForestClassifier #25601
LGTM.
Thanks for picking this up in attempt to wrap it up :)
@betatim Hi, I merged main to keep the branch up to date. It has also been a while since your approval, so I wanted to make sure the PR doesn't stall by accident and ask whether there's anything else I should do to get it merged :)
Thank you for the PR @JanFidor !
doc/whats_new/v1.3.rst
Outdated
- |Fix| :class:`ensemble.RandomForestClassifier` raise more descriptive ValueError when round(n_samples * max_samples) < 1
:pr:`25601` by :user:`Jan Fidor <JanFidor>`.
Since this is a backward-incompatible change, we will need to add this to the "Changed Model section" too. For example, the following works on main but will fail after this PR:
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
X, y = load_wine(return_X_y=True)
clf = RandomForestClassifier(max_samples=1e-4)
clf.fit(X,y)
- |Fix| :class:`ensemble.RandomForestClassifier` raise more descriptive ValueError when round(n_samples * max_samples) < 1
:pr:`25601` by :user:`Jan Fidor <JanFidor>`.
- |Fix| :meth:`ensemble.RandomForestClassifier.fit` raises a more descriptive `ValueError`
when `max_samples` is a float and `round(n_samples * max_samples) < 1`.
:pr:`25601` by :user:`Jan Fidor <JanFidor>`.
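For context, a minimal sketch of why the example above trips the new check. It assumes the wine dataset has 178 samples and hard-codes that number instead of loading the data, so the variable names are illustrative only:

```python
# Assumption: load_wine(return_X_y=True) yields 178 samples; the count is
# hard-coded so this sketch runs without scikit-learn installed.
n_samples = 178
max_samples = 1e-4  # the float passed to RandomForestClassifier above

# 178 * 1e-4 is about 0.0178, which rounds down to zero samples per tree.
n_samples_bootstrap = round(n_samples * max_samples)
print(n_samples_bootstrap)
```

With a bootstrap size of zero, the check `round(n_samples * max_samples) < 1` is true, which is the condition the proposed `ValueError` guards against.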
sklearn/ensemble/_forest.py
Outdated
return round(n_samples * max_samples)
result = round(n_samples * max_samples)
if result < 1:
    raise ValueError("round(`max_samples` * `n_samples`) must be >= 1")
I know this does not look like the error message above, but I think the backticks do not add any information:
raise ValueError("round(`max_samples` * `n_samples`) must be >= 1")
raise ValueError("round(max_samples * n_samples) must be >= 1")
Honestly, I added them to stay consistent with this raise, and I'm inclined to agree that the backticks are a little redundant. So if it's okay, I'd like to delete them from the error messages in this file to stay consistent.
So if it's okay I'd like to delete them from error messages in this file to stay consistent
Even if it's small, I prefer not to expand the scope of this PR, so it's easier to review and merge. We can clean up the file in a separate follow-up PR.
def test_raises_descriptive_bootstrap_error():
    X, y = datasets.load_wine(return_X_y=True)
    forest = RandomForestClassifier(max_samples=1e-4, class_weight="balanced_subsample")
    warn_msg = "round\\(`max_samples` \\* `n_samples`\\) must be >= 1"
Using re.escape makes the following look a little cleaner:
warn_msg = "round\\(`max_samples` \\* `n_samples`\\) must be >= 1"
warn_msg = re.escape("round(max_samples * n_samples) must be >= 1")
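A small sketch of what the suggestion buys, using only the standard library. `pytest.raises(match=...)` applies the pattern with `re.search`, so the same call is used here directly; the variable names are illustrative and not taken from the test file:

```python
import re

message = "round(max_samples * n_samples) must be >= 1"

# Hand-written escaping of the regex metacharacters "(", ")" and "*".
hand_escaped = "round\\(max_samples \\* n_samples\\) must be >= 1"
# re.escape produces an equivalent pattern from the literal message.
auto_escaped = re.escape(message)

hand_match = re.search(hand_escaped, message)
auto_match = re.search(auto_escaped, message)

# Used unescaped, the parentheses become a group and "*" a quantifier,
# so the pattern no longer matches the literal text.
raw_match = re.search(message, message)

print(bool(hand_match), bool(auto_match), raw_match)
```

Both escaped patterns match the message, while the raw string used as a pattern does not, which is why some form of escaping is needed in the test.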
@@ -1807,3 +1807,11 @@ def test_read_only_buffer(monkeypatch):

    clf = RandomForestClassifier(n_jobs=2, random_state=rng)
    cross_val_score(clf, X, y, cv=2)


def test_raises_descriptive_bootstrap_error():
def test_raises_descriptive_bootstrap_error():
def test_raises_bootstrap_error_when_max_samples_too_low():
    """Check that an error is raised when max_samples is configured too low.

    Non-regression test for gh-24037.
    """
def test_raises_descriptive_bootstrap_error():
    X, y = datasets.load_wine(return_X_y=True)
    forest = RandomForestClassifier(max_samples=1e-4, class_weight="balanced_subsample")
Can we use a pytest.mark.parametrize to check the normal case as well?
@pytest.mark.parametrize("class_weight", ["balanced_subsample", None])
def test_raises_bootstrap_error_when_max_samples_too_low(class_weight):
    ...
    forest = RandomForestClassifier(max_samples=1e-4, class_weight=class_weight)
@thomasjpfan, @glemaitre, would it be bad to not error but just return 1:
Yea, I'll be okay with returning one. I would have preferred
@jeremiedbb @thomasjpfan just wanted to get confirmation, shall I make the change to stay backward compatible and delete the "Changed Model section" entry?
Yes, you can make the modifications in this PR
Thanks for the update! Could you update the title to reflect the new behavior in this PR? (The title will become the commit message.)
The docstrings for RandomForestClassifier and RandomForestRegressor need to be updated:
scikit-learn/sklearn/ensemble/_forest.py
Lines 1286 to 1288 in 4180b07
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
  `max_samples` should be in the interval `(0.0, 1.0]`.
I prefer it to be explicit and say max(round(n_samples * max_samples), 1).
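A minimal sketch of that explicit formula for the float case. The helper name `bootstrap_size` is hypothetical and this is not the code in `_forest.py`; it only illustrates clamping the rounded value to at least one sample:

```python
def bootstrap_size(n_samples: int, max_samples: float) -> int:
    """Hypothetical helper: samples drawn per tree when max_samples is a float.

    Assumes max_samples lies in the interval (0.0, 1.0].
    """
    # Clamp to 1 so a very small fraction still yields one sample.
    return max(round(n_samples * max_samples), 1)


# Assumption: 178 is the number of rows in the wine dataset used in the tests.
too_low = bootstrap_size(178, 1e-4)   # rounds to 0, clamped to 1
half = bootstrap_size(178, 0.5)       # 89
everything = bootstrap_size(178, 1.0) # 178
print(too_low, half, everything)
```

Returning 1 instead of raising keeps a call such as `RandomForestClassifier(max_samples=1e-4)` working, which is the backward-compatible behavior agreed on above.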
LGTM. Thanks @JanFidor
LGTM
Thanks for the reviews and help @betatim @thomasjpfan and @jeremiedbb !
Reference Issues/PRs
Fixes #24037
Superseded #25140
Credit to @mohitthakur13 @Da-Lan for the solution and to @sbendimerad for the contribution