Add multilabel supp. to grid_search() [1/2] (#5)
Evaluation.grid_search() now supports multi-label classification when using
the "test" method. It keeps all previous standard metrics (precision,
recall, f1-score, accuracy) and adds two new ones, 'hamming-loss' and
'exact-match' (equivalent to 'accuracy').
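
A rough usage sketch of the new path (an editorial illustration, not part of the commit): the documents, label names, and hyperparameter grid below are made up, and the one-list-of-labels-per-document format is inferred from the y_data handling this commit adds to grid_search().

    from pyss3 import SS3
    from pyss3.util import Evaluation

    clf = SS3()

    # Made-up multi-label data: each document gets a list of labels.
    x_data = ["the match ended in a draw",
              "the new tax bill passed",
              "star player weighs in on the tax bill"]
    y_data = [["sports"], ["politics"], ["sports", "politics"]]

    # Assumption: fit() accepts lists of labels to enable multi-label mode.
    clf.fit(x_data, y_data)

    # Grid search with the "test" method (no k_fold) and one of the new metrics.
    s, l, p, a = Evaluation.grid_search(
        clf, x_data, y_data,
        s=[0.2, 0.32, 0.44],
        metric="hamming-loss"
    )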
sergioburdisso committed May 16, 2020
1 parent aacd3a0 commit 925156d
Showing 2 changed files with 31 additions and 17 deletions.
47 changes: 30 additions & 17 deletions pyss3/util.py
@@ -942,7 +942,7 @@ def plot(html_path='./', open_browser=True):

@staticmethod
def get_best_hyperparameters(
metric='accuracy', metric_target='macro avg', tag=None, method=None, def_cat=None
metric=None, metric_target='macro avg', tag=None, method=None, def_cat=None
):
"""
Return the best hyperparameter values for the given metric.
@@ -962,10 +962,12 @@ def get_best_hyperparameters(
that is, whether we want to measure some averaging performance or the
performance on a particular category.
:param metric: the evaluation metric, options are: 'accuracy', 'f1-score',
'precision', or 'recall'. In addition, In multi-label
classification also 'hamming-loss' and 'exact-match'
(default: 'accuracy').
:param metric: the evaluation metric to return, options are:
'accuracy', 'f1-score', 'precision', or 'recall'
When working with multi-label classification problems,
two more options are allowed: 'hamming-loss' and 'exact-match'.
Note: 'exact-match' will produce the same result as 'accuracy'.
(default: 'accuracy', or 'hamming-loss' for the multi-label case).
:type metric: str
:param metric_target: the target we aim at measuring with the given
metric. Options are: 'macro avg', 'micro avg',
@@ -991,15 +993,17 @@ def get_best_hyperparameters(
if not Evaluation.__clf__:
raise ValueError(ERROR_CNA)

if metric not in METRICS + GLOBAL_METRICS:
raise KeyError(ERROR_NAM % str(metric))

metric = metric if metric != STR_EXACT_MATCH else STR_ACCURACY

l_tag, l_method, l_def_cat = Evaluation.__get_last_evaluation__()
tag, method, def_cat = tag or l_tag, method or l_method, def_cat or l_def_cat
cache = Evaluation.__cache__[tag][method][def_cat]

multilabel = STR_HAMMING_LOSS in cache
metric = metric or (STR_ACCURACY if not multilabel else STR_HAMMING_LOSS)
metric = metric if metric != STR_EXACT_MATCH else STR_ACCURACY

if metric not in METRICS + GLOBAL_METRICS:
raise KeyError(ERROR_NAM % str(metric))

if metric_target not in AVGS and metric_target not in cache["categories"]:
raise KeyError(ERROR_NAT % str(metric_target))
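
The hunk above makes the default metric depend on whether the cached evaluation is multi-label. A hedged illustration of the resulting behaviour, assuming an evaluation (e.g. the grid search sketched earlier) has already been run and that get_best_hyperparameters() returns the usual (s, l, p, a) tuple:

    # With a multi-label evaluation cached, omitting `metric` now defaults
    # to 'hamming-loss' instead of 'accuracy'.
    s, l, p, a = Evaluation.get_best_hyperparameters()

    # 'exact-match' is mapped to 'accuracy' before the lookup, so both
    # names select the same best configuration.
    assert Evaluation.get_best_hyperparameters(metric="exact-match") == \
        Evaluation.get_best_hyperparameters(metric="accuracy")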

@@ -1009,6 +1013,7 @@ def get_best_hyperparameters(
best = c_metric["best"]
else:
if metric_target in AVGS:
print(metric, metric_target)
best = c_metric[metric_target]["best"]
else:
best = c_metric["categories"][metric_target]["best"]
@@ -1084,7 +1089,7 @@ def show_best(tag=None, method=None, def_cat=None, metric=None, avg=None):
" Best %s: %s %s" % (
ps.green("hamming loss"),
ps.warning(round_fix(1 - best["value"])),
ps.blue("(s %s l %s p %s a %s)") % (
ps.blue("(s=%s, l=%s, p=%s, a=%s)") % (
best["s"], best["l"], best["p"], best["a"]
)
)
@@ -1095,7 +1100,7 @@ def show_best(tag=None, method=None, def_cat=None, metric=None, avg=None):
" Best %s: %s %s" % (
ps.green("accuracy" if not multilabel else "exact match ratio"),
ps.warning(best["value"]),
ps.blue("(s %s l %s p %s a %s)") % (
ps.blue("(s=%s, l=%s, p=%s, a=%s)") % (
best["s"], best["l"], best["p"], best["a"]
)
)
@@ -1111,7 +1116,7 @@ def show_best(tag=None, method=None, def_cat=None, metric=None, avg=None):
best = c_metric["categories"][cat]["best"]
print((" " * 8) + "%s: %s %s" % (
cat, ps.warning(best["value"]),
ps.blue("(s %s l %s p %s a %s)") % (
ps.blue("(s=%s, l=%s, p=%s, a=%s)") % (
best["s"], best["l"], best["p"], best["a"]
)
))
@@ -1125,7 +1130,7 @@ def show_best(tag=None, method=None, def_cat=None, metric=None, avg=None):
best = c_metric[av]["best"]
print((" " * 10) + "%s: %s %s" % (
ps.header(av), ps.warning(best["value"]),
ps.blue("(s %s l %s p %s a %s)")
ps.blue("(s=%s, l=%s, p=%s, a=%s)")
% (
best["s"], best["l"],
best["p"], best["a"]
@@ -1456,7 +1461,7 @@ def kfold_cross_validation(
def grid_search(
clf, x_data, y_data, s=None, l=None, p=None, a=None,
k_fold=None, n_grams=None, def_cat=STR_MOST_PROBABLE, prep=True,
tag=None, metric='accuracy', metric_target='macro avg', cache=True, extended_pbar=False
tag=None, metric=None, metric_target='macro avg', cache=True, extended_pbar=False
):
"""
Perform a grid search using the provided hyperparameter values.
@@ -1547,7 +1552,10 @@ def grid_search(
:type tag: str
:param metric: the evaluation metric to return, options are:
'accuracy', 'f1-score', 'precision', or 'recall'
(default: 'accuracy').
When working with multi-label classification problems,
two more options are allowed: 'hamming-loss' and 'exact-match'.
Note: 'exact-match' will produce the same result as 'accuracy'.
(default: 'accuracy', or 'hamming-loss' for the multi-label case).
:type metric: str
:param metric_target: the target we aim at measuring with the given
metric. Options are: 'macro avg', 'micro avg',
@@ -1570,6 +1578,7 @@ def grid_search(
raise TypeError(ERROR_IKT)

from . import SS3
multilabel = clf.__multilabel__
Evaluation.set_classifier(clf)
n_grams = n_grams or (len(clf.__max_fr__[0]) if len(clf.__max_fr__) > 0 else 1)
tag = tag or Evaluation.__cache_get_default_tag__(clf, x_data, n_grams)
@@ -1584,7 +1593,11 @@ def grid_search(

Print.show()
if not k_fold: # if test
x_test, y_test = x_data, [clf.get_category_index(y) for y in y_data]
if not multilabel:
x_test, y_test = x_data, [clf.get_category_index(y) for y in y_data]
else:
x_test, y_test = x_data, [[clf.get_category_index(y) for y in yy]
for yy in y_data]
Evaluation.__grid_search_loop__(
clf, x_test, y_test, s, l, p, a, 1, 0,
def_cat, tag, clf.get_categories(), cache,
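
The only functional change inside grid_search() itself is the y_data conversion in the last hunk: single-label targets map each category name to an index, while multi-label targets map each list of names to a list of indices. A toy illustration, with a made-up dictionary standing in for clf.get_category_index():

    cat_index = {"sports": 0, "politics": 1}

    y_single = ["sports", "politics"]
    print([cat_index[y] for y in y_single])                # [0, 1]

    y_multi = [["sports"], ["sports", "politics"]]
    print([[cat_index[y] for y in yy] for yy in y_multi])  # [[0], [0, 1]]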
1 change: 1 addition & 0 deletions tests/test_util.py
@@ -176,6 +176,7 @@ def test_evaluation(mocker):
# grid_search
# OK
s0, l0, p0, a0 = Evaluation.grid_search(clf, x_data, y_data, s=ss)
s1, l1, p1, a1 = Evaluation.grid_search(clf_ml, x_data_ml, y_data_ml, s=ss, l=ll, p=pp)
s1, l1, p1, a1 = Evaluation.grid_search(clf, x_data, y_data, s=ss, l=ll, p=pp)
assert s0 == s1 and l0 == l1 and p0 == p1 and a0 == a1
s0, l0, p0, a0 = Evaluation.grid_search(clf, x_data, y_data, k_fold=4)
