Skip to content

Commit

Permalink
Fix ZeroDivisionError bug (#10)
Browse files Browse the repository at this point in the history
When the model was trained using integer numbers as labels, and when one
of the labels was the 0 value, the class with this labels and all its
training documents were ignored. Therefore, if the problem being solved
was a binary class classification, the 0 class was never learned and as
such the model was trained using only a single class which lead to the
ZeroDivisionError.

The fix was really straightforward, the issue was cause by the following
condition in the SS3.learn() method:

    if not doc or not cat:
            return

Replacing the ``not cat`` with ``cat is None`` solved the issue.

Resolves: #10
  • Loading branch information
sergioburdisso committed Jun 19, 2020
1 parent fecceb6 commit 236a942
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pyss3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ def __sn__(self, ngram, icat):

s = sum([min(v, 1) for v in m_values])

return pow((c - (s + 1)) / ((c - 1) * (s + 1)), self.__p__)
try:
return pow((c - (s + 1)) / ((c - 1) * (s + 1)), self.__p__)
except ZeroDivisionError: # if c <= 1
return 1.

def __sg_vanilla__(self, ngram, icat, cache=True):
"""The original significance (sg) function definition."""
Expand Down Expand Up @@ -1936,7 +1939,7 @@ def learn(self, doc, cat, n_grams=1, prep=True, update=True):
"""
self.__cv_cache__ = None

if not doc or not cat:
if not doc or cat is None:
return

try:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_pyss3.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,19 @@ def test_pyss3_ss3(mockers):
with pytest.raises(ValueError):
clf = SS3("hyperparameter")

# Using integer labels
test_x = ["this is the first document"] * 5 + ["this is the second document"] * 5
test_y = [0] * 5 + [1] * 5
clf = SS3()
clf.train(test_x, test_y)
assert clf.classify_label("this is the first document") == 0
assert clf.classify_label("this is the second document") == 1

# traning only with one category
clf = SS3()
clf.train(["this is the first document"], ["first"])

# training different cases
clf = SS3(
s=.45, l=.5, p=1, a=0,
cv_m=STR_NORM_GV_XAI, sn_m=STR_XAI
Expand Down

0 comments on commit 236a942

Please sign in to comment.