Skip to content

Commit

Permalink
Add first & last aggregation ops (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrikDurdevic committed Aug 18, 2023
1 parent 8e79f44 commit 2c7cd68
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 13 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ v0.6.0 (, 2023)
* Add Airbnb Reviews dataset [#125][#125]
* Add Store dataset [#131][#131]
* Add Order By Operation (OrderByOp, along with IdentityOp and TransformationOpBase) [#138][#138]
* Add First and Last Aggregation Operations [#144][#144]
* Fixes
* Rename `_execute_operations_on_df` to `target` in executed prediction problem dataframe [#124][#124]
* Clean up operation description generation [#118][#118]
Expand All @@ -23,6 +24,7 @@ v0.6.0 (, 2023)
[#125]: <https://github.com/trane-dev/Trane/pull/125>
[#131]: <https://github.com/trane-dev/Trane/pull/131>
[#138]: <https://github.com/trane-dev/Trane/pull/138>
[#144]: <https://github.com/trane-dev/Trane/pull/144>


v0.5.0 (July 27, 2023)
Expand Down
13 changes: 11 additions & 2 deletions tests/integration_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
AvgAggregationOp,
CountAggregationOp,
ExistsAggregationOp,
FirstAggregationOp,
LastAggregationOp,
MajorityAggregationOp,
MaxAggregationOp,
MinAggregationOp,
Expand All @@ -28,8 +30,15 @@
MaxAggregationOp: " the maximum <{}> in all related records",
MinAggregationOp: " the minimum <{}> in all related records",
MajorityAggregationOp: " the majority <{}> in all related records",
CountAggregationOp: "the number of records",
ExistsAggregationOp: "if there exists a record",
CountAggregationOp: " the number of records",
ExistsAggregationOp: " if there exists a record",
FirstAggregationOp: " the first <{}> in all related records",
LastAggregationOp: " the last <{}> in all related records",
}

# Maps each transformation op class to the description fragment associated
# with it. IdentityOp maps to an empty string; OrderByOp's fragment carries a
# "<{}>" placeholder (presumably filled with the column name -- confirm
# against the code that formats these strings).
transform_op_str_dict = {
IdentityOp: "",
OrderByOp: " sorted by <{}>",
}

transform_op_str_dict = {
Expand Down
16 changes: 16 additions & 0 deletions tests/ops/test_aggregation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from trane.ops.aggregation_ops import (
AvgAggregationOp,
CountAggregationOp,
FirstAggregationOp,
LastAggregationOp,
MajorityAggregationOp,
MaxAggregationOp,
MinAggregationOp,
Expand Down Expand Up @@ -55,6 +57,20 @@ def test_min_agg_op(df):
assert "the minimum <col> in all related records" in op.generate_description()


def test_first_agg_op(df):
    """FirstAggregationOp yields the first "col" value and describes itself."""
    first_op = FirstAggregationOp("col")
    result = first_op(df)
    expected = df["col"].iloc[0]
    assert result == expected
    text = first_op.generate_description()
    assert "the first <col> in all related records" in text


def test_last_agg_op(df):
    """LastAggregationOp yields the last "col" value and describes itself."""
    last_op = LastAggregationOp("col")
    result = last_op(df)
    expected = df["col"].iloc[-1]
    assert result == expected
    text = last_op.generate_description()
    assert "the last <col> in all related records" in text


@pytest.mark.parametrize(
"dtype",
[
Expand Down
4 changes: 4 additions & 0 deletions tests/ops/test_op_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
AvgAggregationOp,
CountAggregationOp,
ExistsAggregationOp,
FirstAggregationOp,
LastAggregationOp,
MajorityAggregationOp,
MaxAggregationOp,
MinAggregationOp,
Expand Down Expand Up @@ -33,6 +35,8 @@ def test_get_aggregation_ops():
MinAggregationOp,
MajorityAggregationOp,
ExistsAggregationOp,
FirstAggregationOp,
LastAggregationOp,
]


Expand Down
12 changes: 12 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
AvgAggregationOp,
CountAggregationOp,
ExistsAggregationOp,
FirstAggregationOp,
LastAggregationOp,
MajorityAggregationOp,
MaxAggregationOp,
MinAggregationOp,
Expand Down Expand Up @@ -237,6 +239,16 @@ def test_check_operations_cat():
result, modified_meta = _check_operations_valid(operations, table_meta)
assert result is True

# For each <id> predict the first <state> in all related records
operations = [AllFilterOp(None), IdentityOp(None), FirstAggregationOp("state")]
result, modified_meta = _check_operations_valid(operations, table_meta)
assert result is True

# For each <id> predict the last <state> in all related records
operations = [AllFilterOp(None), IdentityOp(None), LastAggregationOp("state")]
result, modified_meta = _check_operations_valid(operations, table_meta)
assert result is True


def test_foreign_key():
table_meta = {
Expand Down
35 changes: 24 additions & 11 deletions trane/ops/aggregation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,27 @@ def label_function(self, dataslice):
return len(dataslice) > 0


# class FirstAggregationOp(AggregationOpBase):
# input_output_types = [("category", "category")]
# description = " the first <{}> in all related records"
# def label_function(self, dataslice):
# return dataslice[self.column_name].iloc[0]

# class LastAggregationOp(AggregationOpBase):
# input_output_types = [("category", "category")]
# description = " the last <{}> in all related records"
# def label_function(self, dataslice):
# return dataslice[self.column_name].iloc[-1]
class FirstAggregationOp(AggregationOpBase):
    """Aggregation op that labels a data slice with a column's first value."""

    # Accepted (input, output) type pairs for this op.
    input_output_types = [("category", "category")]
    # Template for the description; "{}" is replaced by the column name.
    description = " the first <{}> in all related records"

    def generate_description(self):
        """Return the description template filled in with the column name."""
        template = self.description
        return template.format(self.column_name)

    def label_function(self, dataslice):
        """Return the first entry of the column, or None for an empty slice."""
        if len(dataslice) > 0:
            column = dataslice[self.column_name]
            return column.iloc[0]
        # Nothing to pick from when the slice has no rows.
        return None


class LastAggregationOp(AggregationOpBase):
    """Aggregation op that labels a data slice with a column's last value."""

    # Accepted (input, output) type pairs for this op.
    input_output_types = [("category", "category")]
    # Template for the description; "{}" is replaced by the column name.
    description = " the last <{}> in all related records"

    def generate_description(self):
        """Return the description template filled in with the column name."""
        template = self.description
        return template.format(self.column_name)

    def label_function(self, dataslice):
        """Return the last entry of the column, or None for an empty slice."""
        if len(dataslice) > 0:
            column = dataslice[self.column_name]
            return column.iloc[-1]
        # Nothing to pick from when the slice has no rows.
        return None

0 comments on commit 2c7cd68

Please sign in to comment.