Skip to content

Commit

Permalink
Add multilabel supp. to kfold_cross_validation(#5)
Browse files Browse the repository at this point in the history
Evaluation.kfold_cross_validation() now supports multi-label
classification as well. It supports all previous standard metrics
(precision, recall, f1-score, accuracy) plus two new ones,
'hamming-lose' and 'exact-match' (equivalent to 'accuracy').
  • Loading branch information
sergioburdisso committed May 16, 2020
1 parent ef2419b commit aacd3a0
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 34 deletions.
63 changes: 43 additions & 20 deletions pyss3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.metrics import classification_report, accuracy_score, hamming_loss
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

import numpy as np
import unicodedata
Expand Down Expand Up @@ -458,7 +459,7 @@ def __cache_remove__(tag, method, def_cat, s, l, p, a, simulate=False):
@staticmethod
def __classification_report_k_fold__(
tag, method, def_cat, s, l, p, a, plot=True,
metric='accuracy', metric_target='macro avg'
metric=None, metric_target='macro avg'
):
"""Create the classification report for k-fold validations."""
Print.verbosity_region_begin(VERBOSITY.VERBOSE, force=True)
Expand All @@ -469,6 +470,10 @@ def __classification_report_k_fold__(
cache = Evaluation.__cache_get_evaluations__(tag, method, def_cat)
categories = cache["categories"]

multilabel = STR_HAMMING_LOSS in cache
metric = metric or (STR_ACCURACY if not multilabel else STR_HAMMING_LOSS)
metric = metric if metric != STR_EXACT_MATCH else STR_ACCURACY

name_width = max(len(cn) for cn in categories)
width = max(name_width, len(AVGS[-1]))
head_fmt = '{:>{width}s} ' + ' {:>9}' * len(METRICS)
Expand Down Expand Up @@ -496,22 +501,29 @@ def __classification_report_k_fold__(
report += '\n'

report += "\n\n %s: %.3f\n" % (
Print.style.bold("avg accuracy"), cache["accuracy"]["value"][s][l][p][a]
Print.style.bold("Avg. %s" % ("Exact Match Ratio" if multilabel else "Accuracy")),
cache["accuracy"]["value"][s][l][p][a]
)
if multilabel:
report += " %s: %.3f\n" % (
Print.style.bold("Avg. Hamming Loss"),
round_fix(1 - cache["hamming-loss"]["value"][s][l][p][a])
)

Print.show(report)
Print.verbosity_region_end()

if plot:
if plot and not multilabel:
Evaluation.__plot_confusion_matrices__(
cache["confusion_matrix"][s][l][p][a], categories,
r"$\sigma=%.3f; \lambda=%.3f; \rho=%.3f; \alpha=%.3f$"
%
(s, l, p, a)
)

if metric == STR_ACCURACY:
return cache[metric]["value"][s][l][p][a]
if metric in GLOBAL_METRICS:
val = cache[metric]["value"][s][l][p][a]
return val if metric != STR_HAMMING_LOSS else 1 - val
else:
if metric not in cache:
raise KeyError(ERROR_NAM % str(metric))
Expand Down Expand Up @@ -609,7 +621,7 @@ def __get_global_best__(values):
def __evaluation_result__(
clf, y_true, y_pred, categories, def_cat, cache=True, method="test",
tag=None, folder=False, plot=True, k_fold=1, i_fold=0, force_show=False,
metric='accuracy', metric_target='macro avg'
metric=None, metric_target='macro avg'
):
"""Compute evaluation results and save them to disk (cache)."""
import warnings
Expand All @@ -619,6 +631,8 @@ def __evaluation_result__(
Print.verbosity_region_begin(VERBOSITY.VERBOSE, force=True)

multilabel = clf.__multilabel__
metric = metric or (STR_ACCURACY if not multilabel else STR_HAMMING_LOSS)
metric = metric if metric != STR_EXACT_MATCH else STR_ACCURACY
hammingloss = None

if metric == STR_HAMMING_LOSS and not multilabel:
Expand Down Expand Up @@ -648,7 +662,7 @@ def __evaluation_result__(
Print.show("\n %s: %.3f" % (Print.style.bold("Accuracy"), accuracy))
else:
Print.show("\n %s: %.3f" % (Print.style.bold("Exact Match Ratio"), accuracy))
Print.show("\n %s: %.3f" % (Print.style.bold("Hamming Loss"), hammingloss))
Print.show(" %s: %.3f" % (Print.style.bold("Hamming Loss"), hammingloss))

if not multilabel:
unclassified = None
Expand Down Expand Up @@ -1262,6 +1276,7 @@ def test(
:rtype: float
:raises: EmptyModelError, KeyError, ValueError
"""
multilabel = clf.__multilabel__
Evaluation.set_classifier(clf)
tag = tag or Evaluation.__cache_get_default_tag__(clf, x_test)
s, l, p, a = clf.get_hyperparameters()
Expand All @@ -1274,10 +1289,6 @@ def test(
tag, def_cat, s, l, p, a
)

multilabel = clf.__multilabel__
metric = metric or (STR_ACCURACY if not multilabel else STR_HAMMING_LOSS)
metric = metric if metric != STR_EXACT_MATCH else STR_ACCURACY

# if not cached
if not y_pred or multilabel:
clf.set_hyperparameters(s, l, p, a)
Expand All @@ -1300,7 +1311,7 @@ def test(
@staticmethod
def kfold_cross_validation(
clf, x_train, y_train, k=4, n_grams=None, def_cat=STR_MOST_PROBABLE, prep=True,
tag=None, plot=True, metric='accuracy', metric_target='macro avg', cache=True
tag=None, plot=True, metric=None, metric_target='macro avg', cache=True
):
"""
Perform a Stratified k-fold cross validation on the given training set.
Expand Down Expand Up @@ -1349,7 +1360,10 @@ def kfold_cross_validation(
:type plot: bool
:param metric: the evaluation metric to return, options are:
'accuracy', 'f1-score', 'precision', or 'recall'
(default: 'accuracy').
When working with multi-label classification problems,
two more options are allowed: 'hamming-loss' and 'exact-match'.
Note: exact match will produce the same result than 'accuracy'.
(default: 'accuracy', or 'hamming-loss' for multi-label case).
:type metric: str
:param metric_target: the target we aim at measuring with the given
metric. Options are: 'macro avg', 'micro avg',
Expand All @@ -1375,6 +1389,9 @@ def kfold_cross_validation(
if n_grams is not None and (not isinstance(n_grams, int) or n_grams < 1):
raise ValueError(ERROR_INGV)

multilabel = clf.__multilabel__
kfold_split = MultilabelStratifiedKFold if multilabel else StratifiedKFold

Evaluation.set_classifier(clf)
n_grams = n_grams or (len(clf.__max_fr__[0]) if len(clf.__max_fr__) > 0 else 1)
tag = tag or Evaluation.__cache_get_default_tag__(clf, x_train, n_grams)
Expand All @@ -1387,16 +1404,15 @@ def kfold_cross_validation(
s, l, p, a = clf.get_hyperparameters()
categories = clf.get_categories()
x_train, y_train = np.array(x_train), np.array(y_train)
skf = StratifiedKFold(n_splits=k)
skf = kfold_split(n_splits=k)
pbar_desc = "k-fold validation"
progress_bar = tqdm(total=k, desc=pbar_desc)
for i_fold, (train_ix, test_ix) in enumerate(skf.split(x_train, y_train)):
y_train_split = membership_matrix(clf, y_train).todense() if multilabel else y_train
for i_fold, (train_ix, test_ix) in enumerate(skf.split(x_train, y_train_split)):
if not cache or not Evaluation.__cache_is_in__(
tag, method, def_cat, s, l, p, a
):
x_train_fold, y_train_fold = x_train[train_ix], y_train[train_ix]
y_test_fold = [clf.get_category_index(y) for y in y_train[test_ix]]
x_test_fold = x_train[test_ix]

clf_fold = SS3()
clf_fold.set_hyperparameters(s, l, p, a)
Expand All @@ -1405,6 +1421,13 @@ def kfold_cross_validation(
clf_fold.fit(x_train_fold, y_train_fold, n_grams,
prep=prep, leave_pbar=False)

if not multilabel:
y_test_fold = [clf_fold.get_category_index(y) for y in y_train[test_ix]]
else:
y_test_fold = [[clf_fold.get_category_index(y) for y in yy]
for yy in y_train[test_ix]]
x_test_fold = x_train[test_ix]

progress_bar.set_description_str(pbar_desc + " [classifying...]")
y_pred = clf_fold.predict(x_test_fold, def_cat,
prep=prep, labels=False, leave_pbar=False)
Expand Down Expand Up @@ -2203,7 +2226,7 @@ def membership_matrix(clf, y_data, labels=True, show_pbar=True):
and that the classifier ``clf`` has been trained on 3 categories whose labels are
'labelA', 'labelB', and 'labelC', then, we would have that:
>>> labels2membership(clf, [[], ['labelA'], ['labelB'], ['labelA', 'labelC']])
>>> membership_matrix(clf, [[], ['labelA'], ['labelB'], ['labelA', 'labelC']])
returns the following membership matrix:
Expand Down Expand Up @@ -2234,10 +2257,10 @@ def membership_matrix(clf, y_data, labels=True, show_pbar=True):
labels2index = dict([(c if labels else clf.get_category_index(c), i)
for i, c in enumerate(clf.get_categories())])
y_data_matrix = sparse.lil_matrix((len(y_data), len(labels2index)), dtype="b")

try:
li = np.array([[i, labels2index[l]] for i, ll in enumerate(y_data) for l in ll])
y_data_matrix[li[:, 0], li[:, 1]] = 1
if len(li) > 0:
y_data_matrix[li[:, 0], li[:, 1]] = 1
except KeyError as e:
raise ValueError("The `y_data` contains an unknown label (%s)" % str(e))

Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ six
cython
scikit-learn[alldeps]>=0.20
tqdm>=4.8.4
matplotlib
matplotlib
iterative-stratification
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ six
cython
scikit-learn[alldeps]>=0.20
tqdm>=4.8.4
matplotlib
matplotlib
iterative-stratification
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
'cython',
'scikit-learn[alldeps]>=0.20',
'tqdm>=4.8.4',
'matplotlib'],
'matplotlib',
'iterative-stratification'],
tests_require=['pytest',
'pytest-mock'
'pytest-cov>=2.5'
Expand All @@ -69,5 +70,6 @@
'cython',
'scikit-learn[alldeps]>=0.20',
'tqdm>=4.8.4',
'matplotlib'],
'matplotlib',
'iterative-stratification'],
entry_points={'console_scripts': ['pyss3=pyss3.cmd_line:main']})
28 changes: 27 additions & 1 deletion tests/dataset_ml/train/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,30 @@ this is the synthetic document number 6
>>>>>
7, 7 is the number of this document
>>>>>
this is the last document, and it is really toxic!
this is the last document, and it is really toxic!
>>>>>
this is the synthetic document
number 9
>>>>>
this is the synthetic document
number 10
>>>>>
this is document number 11
>>>>>
this is document number 12
>>>>>
this is document number 13
>>>>>
this is document number 14
>>>>>
this is document number 15
>>>>>
this is document number 16
>>>>>
this is document number 17
>>>>>
this is document number 18
>>>>>
this is document number 19
>>>>>
this is document number 20
14 changes: 13 additions & 1 deletion tests/dataset_ml/train/labels.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,16 @@ toxic,severe_toxic,obscene,insult



toxic
toxic
toxic
severe_toxic
obscene
insult
toxic,insult
severe_toxic,obscene
obscene
insult
toxic
severe_toxic
obscene
insult
2 changes: 1 addition & 1 deletion tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_multilabel():
clf.fit(x_train, y_train)

assert sorted(clf.get_categories()) == ['insult', 'obscene', 'severe_toxic', 'toxic']
assert clf.classify_multilabel("this is a unknown document!") == []
assert clf.classify_multilabel("this is a unknown document!") == ['toxic']

y_pred = [[], ['toxic'], ['severe_toxic'], ['obscene'], ['insult'], ['toxic', 'insult']]

Expand Down
17 changes: 10 additions & 7 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ def test_evaluation(mocker):
pp = [0, 2]
x_data, y_data = Dataset.load_from_files(DATASET_PATH)
x_data_ml, y_data_ml = Dataset.load_from_files_multilabel(
path.join(DATASET_MULTILABEL_PATH, "train_files"),
path.join(DATASET_MULTILABEL_PATH, "file_labels.tsv")
path.join(DATASET_MULTILABEL_PATH, "train/docs.txt"),
path.join(DATASET_MULTILABEL_PATH, "train/labels.txt"),
sep_label=",",
sep_doc="\n>>>>>\n"
)

clf = SS3()
Expand Down Expand Up @@ -107,6 +109,7 @@ def test_evaluation(mocker):
['bla bla bla', "I love this love movie!"],
['pos', 'pos'],
plot=PY3) == 0.5
assert kfold_validation(clf_ml, x_data_ml, y_data_ml, plot=PY3) > 0
assert kfold_validation(clf, x_data, y_data, plot=PY3) > 0
s, l, p, a = clf.get_hyperparameters()
s0, l0, p0, a0 = Evaluation.grid_search(clf, x_data, y_data)
Expand All @@ -122,8 +125,8 @@ def test_evaluation(mocker):

# test
# OK
assert Evaluation.test(clf_ml, x_data_ml, y_data_ml, plot=PY3) == 1 / 3.
assert Evaluation.test(clf_ml, x_data_ml, y_data_ml, metric='exact-match', plot=PY3) == .5
assert Evaluation.test(clf_ml, x_data_ml, y_data_ml, plot=PY3) == .3125
assert Evaluation.test(clf_ml, x_data_ml, y_data_ml, metric='exact-match', plot=PY3) == .3
assert Evaluation.test(clf, x_data, y_data, def_cat='unknown', plot=PY3) == 1
assert Evaluation.test(clf, x_data, y_data, def_cat='neg', plot=PY3) == 1
assert Evaluation.test(clf, x_data, y_data, metric="f1-score", plot=PY3) == 1
Expand Down Expand Up @@ -271,6 +274,6 @@ def test_dataset():
sep_doc="\n>>>>>\n"
)

assert len(y_train) == len(y_train) and len(y_train) == 8
assert y_train == [[], ['toxic', 'severe_toxic', 'obscene', 'insult'],
[], [], [], [], [], ['toxic']]
assert len(y_train) == len(y_train) and len(y_train) == 20
assert y_train[:8] == [[], ['toxic', 'severe_toxic', 'obscene', 'insult'],
[], [], [], [], [], ['toxic']]

0 comments on commit aacd3a0

Please sign in to comment.