Skip to content

Commit

Permalink
BUG: Make NumExprFilter return ndarray.
Browse files Browse the repository at this point in the history
- Previously it was returning a DataFrame because of how we applied an &
  with a DataFrame mask.  The error was masked by the fact that
  `np.assert_array_equal` coerces inputs to arrays before comparing.

- Added `zp.utils.test_utils.check_arrays`, which checks type equality
  before calling `np.assert_array_equal`.
  • Loading branch information
ssanderson committed Aug 3, 2015
1 parent 67c56f7 commit 5da03d2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/modelling/test_numerical_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
isnan,
zeros,
)
from numpy.testing import assert_array_equal
from pandas import (
DataFrame,
date_range,
Expand All @@ -30,6 +29,7 @@
NUMEXPR_MATH_FUNCS,
)
from zipline.modelling.factor import TestingFactor
from zipline.utils.test_utils import check_arrays


class F(TestingFactor):
Expand Down Expand Up @@ -67,7 +67,7 @@ def check_output(self, expr, expected):
[self.fake_raw_data[input_] for input_ in expr.inputs],
self.mask,
)
assert_array_equal(result, full((5, 5), expected))
check_arrays(result, expected)

def check_constant_output(self, expr, expected):
self.assertFalse(isnan(expected))
Expand Down
5 changes: 2 additions & 3 deletions zipline/modelling/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,10 @@ def compute_from_arrays(self, arrays, mask):
"""
Compute our result with numexpr, then apply `mask`.
"""
numexpr_result = super(NumExprFilter, self).compute_from_arrays(
return super(NumExprFilter, self).compute_from_arrays(
arrays,
mask,
)
return numexpr_result & mask
) & mask.values


class PercentileFilter(SingleInputMixin, Filter):
Expand Down
15 changes: 15 additions & 0 deletions zipline/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
)
from logbook import FileHandler
from mock import patch
from numpy.testing import assert_array_equal
import operator
from zipline.finance.blotter import ORDER_STATUS
from zipline.utils import security_list
Expand Down Expand Up @@ -311,3 +312,17 @@ def make_simple_asset_info(assets, start_date, end_date, symbols=None):
'exchange': 'TEST',
}
)


def check_arrays(left, right, err_msg='', verbose=True):
"""
Wrapper around np.assert_array_equal that also verifies that inputs are
ndarrays.
See Also
--------
np.assert_array_equal
"""
if type(left) != type(right):
raise AssertionError("%s != %s" % (type(left), type(right)))
return assert_array_equal(left, right, err_msg=err_msg, verbose=True)

0 comments on commit 5da03d2

Please sign in to comment.