forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAINT unpack 0-dim NumPy array instead of implicit conversion (scikit…
…-learn#26345) Co-authored-by: Jérémie du Boisberranger <[email protected]>
- Loading branch information
1 parent
66733c4
commit 9eea5b7
Showing
7 changed files
with
18 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# Li Li <[email protected]> | ||
# Giuseppe Vettigli <[email protected]> | ||
# License: BSD 3 clause | ||
from collections.abc import Iterable | ||
from io import StringIO | ||
from numbers import Integral | ||
|
||
|
@@ -247,7 +248,7 @@ def get_color(self, value): | |
color = list(self.colors["rgb"][np.argmax(value)]) | ||
sorted_values = sorted(value, reverse=True) | ||
if len(sorted_values) == 1: | ||
alpha = 0 | ||
alpha = 0.0 | ||
else: | ||
alpha = (sorted_values[0] - sorted_values[1]) / (1 - sorted_values[1]) | ||
else: | ||
|
@@ -256,8 +257,6 @@ def get_color(self, value): | |
alpha = (value - self.colors["bounds"][0]) / ( | ||
self.colors["bounds"][1] - self.colors["bounds"][0] | ||
) | ||
# unpack numpy scalars | ||
alpha = float(alpha) | ||
# compute the color as alpha against white | ||
color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color] | ||
# Return html color code in #RRGGBB format | ||
|
@@ -277,8 +276,12 @@ def get_fill_color(self, tree, node_id): | |
if tree.n_outputs == 1: | ||
node_val = tree.value[node_id][0, :] / tree.weighted_n_node_samples[node_id] | ||
if tree.n_classes[0] == 1: | ||
# Regression | ||
# Regression or degraded classification with single class | ||
node_val = tree.value[node_id][0, :] | ||
if isinstance(node_val, Iterable) and self.colors["bounds"] is not None: | ||
# Only unpack the float only for the regression tree case. | ||
# Classification tree requires an Iterable in `get_color`. | ||
node_val = node_val.item() | ||
else: | ||
# If multi-output color node by impurity | ||
node_val = -tree.impurity[node_id] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters