Skip to content

Commit

Permalink
fix aggregated_by for derived coords (#4947)
Browse files Browse the repository at this point in the history
* fix aggregated_by for derived coords

* fix and test for derived coords

* improve tests, add whatsnew
  • Loading branch information
stephenworsley committed Nov 9, 2022
1 parent 421e193 commit f58682d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
3 changes: 3 additions & 0 deletions docs/src/whatsnew/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ This document explains the changes made to Iris for this release
:meth:`~iris.cube.Cube.cell_measure` and :meth:`~iris.cube.Cube.ancillary_variable`.
(:issue:`4898`, :pull:`4928`)

#. `@stephenworsley`_ fixed a bug which caused derived coordinates to be realised
after calling :meth:`iris.cube.Cube.aggregated_by`. (:issue:`3637`, :pull:`4947`)


💣 Incompatible Changes
=======================
Expand Down
15 changes: 13 additions & 2 deletions lib/iris/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -4076,8 +4076,9 @@ def aggregated_by(
# coordinate dimension.
shared_coords = list(
filter(
lambda coord_: coord_ not in groupby_coords,
self.coords(contains_dimension=dimension_to_groupby),
lambda coord_: coord_ not in groupby_coords
and dimension_to_groupby in self.coord_dims(coord_),
self.dim_coords + self.aux_coords,
)
)

Expand Down Expand Up @@ -4109,6 +4110,11 @@ def aggregated_by(
for coord in groupby_coords + shared_coords:
aggregateby_cube.remove_coord(coord)

coord_mapping = {}
for coord in aggregateby_cube.coords():
orig_id = id(self.coord(coord))
coord_mapping[orig_id] = coord

# Determine the group-by cube data shape.
data_shape = list(self.shape + aggregator.aggregate_shape(**kwargs))
data_shape[dimension_to_groupby] = len(groupby)
Expand Down Expand Up @@ -4237,6 +4243,11 @@ def aggregated_by(
aggregateby_cube.add_aux_coord(
new_coord, self.coord_dims(lookup_coord)
)
coord_mapping[id(self.coord(lookup_coord))] = new_coord

aggregateby_cube._aux_factories = []
for factory in self.aux_factories:
aggregateby_cube.add_aux_factory(factory.updated(coord_mapping))

# Attach the aggregate-by data into the aggregate-by cube.
if aggregateby_weights is None:
Expand Down
48 changes: 48 additions & 0 deletions lib/iris/tests/unit/cube/test_Cube__aggregated_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from iris.coords import AncillaryVariable, AuxCoord, CellMeasure, DimCoord
from iris.cube import Cube
import iris.exceptions
from iris.tests.stock import realistic_4d


class Test_aggregated_by(tests.IrisTest):
Expand Down Expand Up @@ -841,5 +842,52 @@ def test_clim_in_no_clim_op(self):
self.assertFalse(categorised_coord.climatological)


class Test_aggregated_by__derived(tests.IrisTest):
def setUp(self):
self.cube = realistic_4d()[:, :10, :6, :8]
self.time_cat_coord = AuxCoord(
[0, 0, 1, 1, 2, 2], long_name="time_cat"
)
self.cube.add_aux_coord(self.time_cat_coord, 0)
height_data = np.zeros(self.cube.shape[1])
height_data[5:] = 1
self.height_cat_coord = AuxCoord(height_data, long_name="height_cat")
self.cube.add_aux_coord(self.height_cat_coord, 1)
self.aggregator = iris.analysis.MEAN

def test_grouped_dim(self):
"""
Check that derived coordinates are maintained when the coordinates they
derive from are aggregated.
"""
result = self.cube.aggregated_by(
self.height_cat_coord,
self.aggregator,
)
assert len(result.aux_factories) == 1
altitude = result.coord("altitude")
assert altitude.shape == (2, 6, 8)

# Check the bounds are derived as expected.
orig_alt_bounds = self.cube.coord("altitude").bounds
bounds_0 = orig_alt_bounds[0::5, :, :, 0]
bounds_1 = orig_alt_bounds[4::5, :, :, 1]
expected_bounds = np.stack([bounds_0, bounds_1], axis=-1)
assert np.array_equal(expected_bounds, result.coord("altitude").bounds)

def test_ungrouped_dim(self):
"""
Check that derived coordinates are preserved when aggregating along a
different axis.
"""
result = self.cube.aggregated_by(
self.time_cat_coord,
self.aggregator,
)
assert len(result.aux_factories) == 1
altitude = result.coord("altitude")
assert altitude == self.cube.coord("altitude")


if __name__ == "__main__":
tests.main()

0 comments on commit f58682d

Please sign in to comment.