Skip to content

Commit

Permalink
Add multilabel support to train/fit methods (#5)
Browse files Browse the repository at this point in the history
The fit/train method now supports multilabel classification. It will
automatically determine if we're dealing with a multilabel
classification problem by looking at the first item of the `y_train`
list. If the first item is a list (of labels), i.e., if it's not a
single label, it will assume we're dealing with a multilabel
classification problem.
  • Loading branch information
sergioburdisso committed May 13, 2020
1 parent 03c9341 commit 4d00476
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 23 deletions.
81 changes: 59 additions & 22 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2173,7 +2173,8 @@ def fit(self, x_train, y_train, n_grams=1, prep=True, leave_pbar=True):
:param x_train: the list of documents
:type x_train: list (of str)
:param y_train: the list of document labels
:type y_train: list (of str)
:type y_train: list of str for singlelabel classification;
list of list of str for multilabel classification.
:param n_grams: indicates the maximum ``n``-grams to be learned
(e.g. a value of ``1`` means only 1-grams (words),
``2`` means 1-grams and 2-grams,
Expand All @@ -2184,36 +2185,64 @@ def fit(self, x_train, y_train, n_grams=1, prep=True, leave_pbar=True):
:param leave_pbar: controls whether to leave the progress bar or
remove it after finishing.
:type leave_pbar: bool
:raises: ValueError
"""
cats = sorted(list(set(y_train)))
stime = time()

x_train, y_train = list(x_train), list(y_train)

x_train = [
"".join([
x_train[i]
if x_train[i] and x_train[i][-1] == '\n'
else
x_train[i] + '\n'
for i in xrange(len(x_train))
if y_train[i] == cat
])
for cat in cats
]
y_train = list(cats)
if len(x_train) != len(y_train):
raise ValueError("`x_train` and `y_train` must have the same length")

if len(y_train) == 0:
raise ValueError("`x_train` and `y_train` are empty")

# if it's a multi-label classification problem
if is_a_collection(y_train[0]):
# flattening y_train
labels = [l for y in y_train for l in y]
else:
labels = y_train

cats = sorted(list(set(labels)))

# if it's a single-label classification problem
if not is_a_collection(y_train[0]):
x_train = [
"".join([
x_train[i]
if x_train[i] and x_train[i][-1] == '\n'
else
x_train[i] + '\n'
for i in xrange(len(x_train))
if y_train[i] == cat
])
for cat in cats
]
y_train = list(cats)

Print.info("about to start training", offset=1)
Print.verbosity_region_begin(VERBOSITY.NORMAL)
progress_bar = tqdm(total=len(x_train), desc="Training",
leave=leave_pbar, disable=Print.is_quiet())
for i in range(len(x_train)):
progress_bar.set_description_str("Training on '%s'" % str(y_train[i]))
self.learn(
x_train[i], y_train[i],
n_grams=n_grams, prep=prep, update=False
)
progress_bar.update(1)

# if it's a multi-label classification problem
if is_a_collection(y_train[0]):
__unknown__ = [STR_UNKNOWN_CATEGORY]
for i in range(len(x_train)):
for label in (y_train[i] if y_train[i] else __unknown__):
self.learn(
x_train[i], label,
n_grams=n_grams, prep=prep, update=False
)
progress_bar.update(1)
else:
for i in range(len(x_train)):
progress_bar.set_description_str("Training on '%s'" % str(y_train[i]))
self.learn(
x_train[i], y_train[i],
n_grams=n_grams, prep=prep, update=False
)
progress_bar.update(1)
progress_bar.close()
self.__prune_tries__()
Print.verbosity_region_end()
Expand Down Expand Up @@ -2558,6 +2587,14 @@ def list_hash(str_list):
return m.hexdigest()


def is_a_collection(o):
    """
    Tell whether ``o`` should be treated as a collection of items.

    An object counts as a collection when it supports item access (it
    exposes ``__getitem__``) and is not a text/bytes string. Strings are
    excluded explicitly because they also support item access.

    :param o: the object to be inspected
    :returns: True if ``o`` is a non-string object supporting item access,
              False otherwise.
    :rtype: bool
    """
    from sys import version_info

    # anything without item access is never considered a collection
    if not hasattr(o, "__getitem__"):
        return False

    if version_info[0] == 2:
        # ``basestring`` only exists in Python 2, hence the version guard
        return not isinstance(o, basestring)

    return not isinstance(o, (str, bytes))


def vsum(v0, v1):
    """
    Add two vectors element by element.

    The length of the result is given by ``v0``; ``v1`` is accessed by
    position, so it must have at least as many elements as ``v0``.

    :param v0: the first vector
    :param v1: the second vector
    :returns: a new list whose i-th element is ``v0[i] + v1[i]``
    :rtype: list
    """
    return [value + v1[pos] for pos, value in enumerate(v0)]
Expand Down
22 changes: 22 additions & 0 deletions tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
pyss3.set_verbosity(VERBOSITY.QUIET)

# Names of the folders holding the test datasets.
DATASET_FOLDER = "dataset"
DATASET_MULTILABEL_FOLDER = "dataset_ml"

# Absolute paths to the dataset folders, resolved relative to the directory
# containing this test module.
dataset_path = path.join(path.abspath(path.dirname(__file__)), DATASET_FOLDER)
dataset_multilabel_path = path.join(path.abspath(path.dirname(__file__)), DATASET_MULTILABEL_FOLDER)

# Training documents and labels loaded once at import time.
# NOTE(review): presumably shared by the tests in this module -- confirm.
x_train, y_train = Dataset.load_from_files(dataset_path, folder_label=False)
x_test = [
Expand Down Expand Up @@ -253,6 +255,22 @@ def test_pyss3_functions():
pyss3.mad([], 0)


def test_multilabel():
    """
    Test training on a multilabel dataset.

    Loads the multilabel training documents and labels, fits a fresh
    classifier on them, and checks that the learned categories are exactly
    the expected ones (including the ``'[unknown]'`` category).
    """
    docs_file = path.join(dataset_multilabel_path, "train/docs.txt")
    labels_file = path.join(dataset_multilabel_path, "train/labels.txt")

    docs, labels = Dataset.load_from_files_multilabel(
        docs_file, labels_file, sep_label=",", sep_doc="\n>>>>>\n"
    )

    classifier = SS3()
    classifier.fit(docs, labels)

    expected_categories = [
        '[unknown]', 'insult', 'obscene', 'severe_toxic', 'toxic'
    ]
    assert sorted(classifier.get_categories()) == expected_categories


def test_pyss3_ss3(mockers):
"""Test SS3."""
clf = SS3(
Expand All @@ -274,6 +292,10 @@ def test_pyss3_ss3(mockers):
clf.predict(x_test)
with pytest.raises(pyss3.EmptyModelError):
clf.predict_proba(x_test)
with pytest.raises(ValueError):
clf.train(x_train, [])
with pytest.raises(ValueError):
clf.train([], [])

# train and predict/classify tests (model: terms are single words)
# cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
Expand Down
2 changes: 1 addition & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
import pytest
import pyss3

PY3 = sys.version_info[0] >= 3
DATASET_FOLDER = "dataset_mr"
DATASET_MULTILABEL_FOLDER = "dataset_ml"
PY3 = sys.version_info[0] >= 3
DATASET_PATH = path.join(path.abspath(path.dirname(__file__)), DATASET_FOLDER)
DATASET_MULTILABEL_PATH = path.join(path.abspath(path.dirname(__file__)), DATASET_MULTILABEL_FOLDER)
TMP_FOLDER = "tests/ss3_models/"
Expand Down

0 comments on commit 4d00476

Please sign in to comment.