Skip to content

Commit

Permalink
TensorPandas fix arithmetic and comparison ops (#143)
Browse files Browse the repository at this point in the history
* Fix arithmetic ops, enable missing tests

* Enable comparison ops tests

* Fix rvalue for other cases
  • Loading branch information
BryanCutler committed Oct 21, 2020
1 parent f5f5f70 commit 9a0b13b
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 10 deletions.
31 changes: 27 additions & 4 deletions text_extensions_for_pandas/array/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import pandas as pd
from pandas.compat import set_function_name
from pandas.core import ops
from pandas.core.dtypes.generic import ABCSeries
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
from pandas.core.indexers import check_array_indexer, validate_indices


Expand Down Expand Up @@ -90,9 +90,32 @@ def _create_method(cls, op, coerce_to_dtype=True, result_dtype=None):

# Binary-operator implementation produced by the enclosing factory. `op`, `cls`
# and `op_name` are not parameters; they are read from the enclosing scope.
# NOTE(review): this listing has lost its indentation and diff markers. The
# three statements right after `lvalues` (unwrap, apply, return) look like the
# previous body and everything after them like its replacement -- confirm
# against the real file before relying on either reading.
def _binop(self, other):
lvalues = self._tensor
rvalues = other._tensor if isinstance(other, (TensorArray, TensorElement)) else other
res = op(lvalues, rvalues)
return cls(res)

if isinstance(other, (ABCDataFrame, ABCSeries, ABCIndexClass)):
# Rely on pandas to unbox and dispatch to us.
return NotImplemented

# divmod returns a tuple
# `op_name` is assigned after this def in the enclosing function, so it is
# looked up when the operator is invoked, not when `_binop` is defined.
if op_name in ["__divmod__", "__rdivmod__"]:
# TODO: return tuple
# div, mod = result
raise NotImplementedError

# Unwrap tensor-backed operands to their underlying `_tensor`; any other
# operand is handed to `op` unchanged.
if isinstance(other, (TensorArray, TensorElement)):
rvalues = other._tensor
else:
rvalues = other

result = op(lvalues, rvalues)

# Force a TensorArray if rvalue is not a scalar
# NOTE(review): `not A or not B` is false only when `other` is BOTH a
# TensorElement and an `np.isscalar` value. If a TensorElement never satisfies
# `np.isscalar` (TODO confirm), the parenthesised test is true for every
# `other`, so a TensorElement on the left always yields a TensorArray. The
# comment above suggests the intent may have been
# `not (isinstance(other, TensorElement) or np.isscalar(other))`.
if isinstance(self, TensorElement) and \
(not isinstance(other, TensorElement) or not np.isscalar(other)):
result_wrapped = TensorArray(result)
else:
result_wrapped = cls(result)

return result_wrapped

op_name = ops._get_op_name(op, True)
return set_function_name(_binop, op_name, cls)
Expand Down
97 changes: 91 additions & 6 deletions text_extensions_for_pandas/array/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,12 @@ def data(dtype):
return pd.array(values, dtype=dtype)


@pytest.fixture
def data_for_twos(dtype):
    """Provide a 100-element extension array in which every value is 2."""
    twos = np.full(100, 2.0)
    return pd.array(twos, dtype=dtype)


@pytest.fixture
def data_missing(dtype):
values = np.array([[np.nan], [9]])
Expand Down Expand Up @@ -720,6 +726,29 @@ def data_for_grouping(dtype):


# Can't import these from pandas.conftest due to its dependencies, so the
# operator names are reproduced here: each forward dunder is followed
# immediately by its reflected ("r"-prefixed) counterpart.
_all_arithmetic_operators = [
    "__" + reflected + base_name + "__"
    for base_name in ("add", "sub", "mul", "floordiv", "truediv", "pow", "mod")
    for reflected in ("", "r")
]


@pytest.fixture(params=_all_arithmetic_operators)
def all_arithmetic_operators(request):
    """Provide one arithmetic dunder-method name per parametrized run."""
    op_name = request.param
    return op_name


@pytest.fixture(params=["__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"])
def all_compare_operators(request):
    """Provide one comparison dunder-method name per parametrized run."""
    op_name = request.param
    return op_name
Expand Down Expand Up @@ -795,19 +824,75 @@ def test_setitem_mask_boolean_array_with_na(self, data, box_in_series):
assert np.all(result == data[0])


# Runs the `base.BaseMissingTests` suite for the tensor extension type. Three
# fillna tests are overridden only to attach a skip marker whose reason is
# "TypeError: No matching signature found"; each override otherwise just
# delegates to the parent implementation.
# NOTE(review): indentation and diff markers are lost in this listing. The
# class-level skip("resolve errors") and the bare `pass` look like lines the
# commit removed -- if they were both still live, the whole class would be
# skipped and the per-test markers would be redundant. Confirm against the
# real file.
@pytest.mark.skip("resolve errors")
class TestPandasMissing(base.BaseMissingTests):
pass
@pytest.mark.skip(reason="TypeError: No matching signature found")
def test_fillna_limit_pad(self, data_missing):
super().test_fillna_limit_pad(data_missing)

@pytest.mark.skip(reason="TypeError: No matching signature found")
def test_fillna_limit_backfill(self, data_missing):
super().test_fillna_limit_backfill(data_missing)

@pytest.mark.skip(reason="TypeError: No matching signature found")
def test_fillna_series_method(self, data_missing, fillna_method):
super().test_fillna_series_method(data_missing, fillna_method)


# Runs the `base.BaseArithmeticOpsTests` suite for the tensor extension type,
# overriding the array-operand test and skipping the error test.
# NOTE(review): indentation and diff markers are lost in this listing. The
# class-level skip("resolve errors") and the bare `pass` look like lines the
# commit removed -- confirm against the real file.
@pytest.mark.skip("resolve errors")
class TestPandasArithmeticOps(base.BaseArithmeticOpsTests):
pass

# Expected errors for tests
# NOTE(review): these four assignments set attributes on the shared base class
# `base.BaseArithmeticOpsTests`, not on this subclass, so the values are
# visible to any other subclass of it as well. Plain class attributes here
# (e.g. `series_scalar_exc = None`) would limit the change to this class --
# confirm which is intended. Presumably None means "no exception expected".
base.BaseArithmeticOpsTests.series_scalar_exc = None
base.BaseArithmeticOpsTests.series_array_exc = None
base.BaseArithmeticOpsTests.frame_scalar_exc = None
base.BaseArithmeticOpsTests.divmod_exc = NotImplementedError

def test_arith_series_with_array(self, data, all_arithmetic_operators):
""" Override because a Series created from a list of TensorElements gets dtype=object;
the right-hand Series here is built with an explicit TensorDtype instead."""
# ndarray & other series
op_name = all_arithmetic_operators
s = pd.Series(data)
# The other operand repeats the first element of `s` once per row.
self.check_opname(
s, op_name, pd.Series([s.iloc[0]] * len(s), dtype=TensorDtype()), exc=self.series_array_exc
)

# NOTE(review): two skip markers are stacked on test_error; the first
# ("resolve errors") looks like a removed diff line -- confirm.
@pytest.mark.skip("resolve errors")
@pytest.mark.skip(reason="TensorArray does not error on ops")
def test_error(self, data, all_arithmetic_operators):
# other specific errors tested in the TensorArray specific tests
pass


# Runs the `base.BaseComparisonOpsTests` suite for the tensor extension type
# with its own comparison helper and operands.
# NOTE(review): indentation and diff markers are lost in this listing; the bare
# `pass` after the class header looks like a line the commit removed.
#@pytest.mark.skip("resolve errors")
class TestPandasComparisonOps(base.BaseComparisonOpsTests):
pass

def _compare_other(self, s, data, op_name, other):
"""
Override to eval result of `all()` as a `ndarray`.
NOTE: test_compare_scalar uses value `0` for other
(the override of test_compare_scalar in this class passes `-1` instead).

The assertions encode the expectation that `s` compares greater than
`other`: `__ne__`, `__gt__` and `__ge__` must hold everywhere, while
`__eq__`, `__lt__` and `__le__` must not hold everywhere.
The `data` argument is accepted but not used here.
Raises ValueError for any other `op_name`.
"""
op = self.get_op_from_name(op_name)
if op_name == "__eq__":
assert not np.all(op(s, other).all())
elif op_name == "__ne__":
assert np.all(op(s, other).all())
elif op_name in ["__lt__", "__le__"]:
assert not np.all(op(s, other).all())
elif op_name in ["__gt__", "__ge__"]:
assert np.all(op(s, other).all())
else:
raise ValueError("Unexpected opname: {}".format(op_name))

def test_compare_scalar(self, data, all_compare_operators):
""" Override to change scalar value to something usable (-1)."""
op_name = all_compare_operators
s = pd.Series(data)
self._compare_other(s, data, op_name, -1)

def test_compare_array(self, data, all_compare_operators):
""" Override to compare `data[1:]` against a TensorDtype Series that repeats `data[0]`."""
op_name = all_compare_operators
# Drop the first element so it can serve as the comparison value; this
# presumably relies on every later element being greater than data[0].
s = pd.Series(data[1:])
other = pd.Series([data[0]] * len(s), dtype=TensorDtype())
self._compare_other(s, data, op_name, other)


@pytest.mark.skip("resolve errors")
Expand Down

0 comments on commit 9a0b13b

Please sign in to comment.