Skip to content

Commit

Permalink
Add multilabel supp. to grid_search()[2/2] (#5)
Browse files Browse the repository at this point in the history
Evaluation.grid_search() now fully supports multi-label classification
using over a given test set or using k-fold cross-validation over a
training set. It supports all previous standard metrics (precision,
recall, and f1-score) plus two new ones, 'hamming-lose' and
'exact-match' (equivalent to 'accuracy').

Resolves: #5
  • Loading branch information
sergioburdisso committed May 16, 2020
1 parent 42bbc65 commit 79f1e9d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
29 changes: 21 additions & 8 deletions pyss3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def __evaluation_result__(
n_cats = len(categories)
if def_cat == STR_UNKNOWN:
if categories[-1] != STR_UNKNOWN_CATEGORY:
categories += [STR_UNKNOWN_CATEGORY]
categories = categories + [STR_UNKNOWN_CATEGORY]
y_pred = [y if y != IDX_UNKNOWN_CATEGORY else n_cats for y in y_pred]
else:
y_pred = membership_matrix(clf, y_pred, labels=False)
Expand Down Expand Up @@ -1426,6 +1426,8 @@ def kfold_cross_validation(
clf_fold.fit(x_train_fold, y_train_fold, n_grams,
prep=prep, leave_pbar=False)

x_train_fold, y_train_fold = None, None # free memory

if not multilabel:
y_test_fold = [clf_fold.get_category_index(y) for y in y_train[test_ix]]
else:
Expand All @@ -1444,6 +1446,7 @@ def kfold_cross_validation(
cache, method, tag,
plot=False, k_fold=k, i_fold=i_fold
)
x_test_fold, y_test_fold, y_pred = None, None, None # free memory

progress_bar.update(1)

Expand Down Expand Up @@ -1607,19 +1610,27 @@ def grid_search(
Print.verbosity_region_begin(VERBOSITY.NORMAL)

x_data, y_data = np.array(x_data), np.array(y_data)
skf = StratifiedKFold(n_splits=k_fold)
kfold_split = MultilabelStratifiedKFold if multilabel else StratifiedKFold
skf = kfold_split(n_splits=k_fold)
categories = clf.get_categories()

for i_fold, (train_ix, test_ix) in enumerate(
skf.split(x_data, y_data)
):
y_data_split = membership_matrix(clf, y_data).todense() if multilabel else y_data
k_fold_splits = enumerate(skf.split(x_data, y_data_split))
for i_fold, (train_ix, test_ix) in k_fold_splits:
x_train, y_train = x_data[train_ix], y_data[train_ix]
y_test = [clf.get_category_index(y) for y in y_data[test_ix]]
x_test = x_data[test_ix]
categories = clf.get_categories()

clf_fold = SS3()
clf_fold.fit(x_train, y_train, n_grams, prep=prep, leave_pbar=False)

x_train, y_train = None, None # free memory

if not multilabel:
y_test = [clf_fold.get_category_index(y) for y in y_data[test_ix]]
else:
y_test = [[clf_fold.get_category_index(y) for y in yy]
for yy in y_data[test_ix]]
x_test = x_data[test_ix]

Evaluation.__grid_search_loop__(
clf_fold, x_test, y_test, s, l, p, a, k_fold, i_fold,
def_cat, tag, categories, cache,
Expand All @@ -1628,6 +1639,8 @@ def grid_search(
)
Evaluation.__cache_update__()

x_test, y_test = None, None # free memory

Print.verbosity_region_end()

return Evaluation.get_best_hyperparameters(metric, metric_target)
Expand Down
1 change: 1 addition & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def test_evaluation(mocker):
s1, l1, p1, a1 = Evaluation.grid_search(clf_ml, x_data_ml, y_data_ml, s=ss, l=ll, p=pp)
s1, l1, p1, a1 = Evaluation.grid_search(clf, x_data, y_data, s=ss, l=ll, p=pp)
assert s0 == s1 and l0 == l1 and p0 == p1 and a0 == a1
s0, l0, p0, a0 = Evaluation.grid_search(clf_ml, x_data_ml, y_data_ml, k_fold=4)
s0, l0, p0, a0 = Evaluation.grid_search(clf, x_data, y_data, k_fold=4)
s0, l0, p0, a0 = Evaluation.grid_search(clf, x_data, y_data, def_cat='unknown', p=pp)
s1, l1, p1, a1 = Evaluation.grid_search(clf, x_data, y_data, def_cat='neg', p=pp)
Expand Down

0 comments on commit 79f1e9d

Please sign in to comment.