
Commit

Use faster lookup
bouweandela committed May 15, 2024
1 parent eb0b340 commit 24615a0
Showing 2 changed files with 76 additions and 59 deletions.
132 changes: 75 additions & 57 deletions lib/iris/_concatenate.py
@@ -5,17 +5,18 @@
 """Automatic concatenation of multiple cubes over one or more existing dimensions."""

 from collections import namedtuple
+from collections.abc import Sequence
 import itertools
-from typing import Iterable
+from typing import Any, Iterable
 import warnings

 import dask
 import dask.array as da
+from dask.base import tokenize
 import numpy as np
-from xxhash import xxh3_64
+from xxhash import xxh3_64_digest

 from iris._lazy_data import concatenate as concatenate_arrays
-from iris._lazy_data import is_masked_data
 import iris.coords
 import iris.cube
 import iris.exceptions
@@ -308,11 +309,7 @@ def _hash_array(a: da.Array | np.ndarray) -> np.int64:
     """

     def arrayhash(x):
-        value = xxh3_64(np.array(x.shape, dtype=np.uint).tobytes())
-        value.update(x.data.tobytes())
-        if is_masked_data(x):
-            value.update(x.mask.tobytes())
-        return np.frombuffer(value.digest(), dtype=np.int64)
+        return np.frombuffer(xxh3_64_digest(x.tobytes()), dtype=np.int64)

     return da.reduction(
         a,
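
Editor's note (not part of the diff): the old `arrayhash` built an incremental hasher, fed it the shape, data buffer, and mask separately, then pulled a digest; the replacement hashes each chunk's raw buffer in a single one-shot call. A minimal sketch of the two call styles against a throwaway array (names and sizes are illustrative only):

    import numpy as np
    from xxhash import xxh3_64, xxh3_64_digest

    x = np.arange(1_000_000)

    # Before: allocate a hasher object and feed it pieces incrementally.
    hasher = xxh3_64(np.array(x.shape, dtype=np.uint).tobytes())
    hasher.update(x.data.tobytes())
    old_digest = hasher.digest()

    # After: hash the raw buffer in one call, with no intermediate object.
    new_digest = np.frombuffer(xxh3_64_digest(x.tobytes()), dtype=np.int64)

Dropping the shape and mask from the hash input also removes the need for `is_masked_data`; mask handling moves to the `array_id` helper introduced below.
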
@@ -325,20 +322,41 @@ def arrayhash(x):
     )


-class _ArrayHash:
-    def __init__(self, value: np.int64, chunks: tuple) -> None:
-        self.value = value
-        self.chunks = chunks
-
-    def __eq__(self, other: "_ArrayHash") -> bool:
+class _ArrayHash(namedtuple("ArrayHash", ["value", "chunks"])):
+    """Container for a hash value and the chunks used when computing it.
+
+    Parameters
+    ----------
+    value : :class:`np.int64`
+        The hash value.
+    chunks : tuple
+        The chunks the array had when the hash was computed.
+    """
+
+    __slots__ = ()
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, self.__class__):
+            return NotImplemented
         if self.chunks != other.chunks:
-            raise ValueError("Unable to compare arrays with different chunks.")
+            raise ValueError(
+                "Unable to compare arrays with different chunks: "
+                f"{self.chunks} != {other.chunks}"
+            )
         return self.value == other.value


+def array_id(array: np.ndarray | da.Array) -> str:
+    """Get a deterministic token representing `array`."""
+    if isinstance(array, np.ma.MaskedArray):
+        # Tokenizing a masked array is much slower than separately tokenizing
+        # the data and mask.
+        return tokenize((tokenize(array.data), tokenize(array.mask)))
+    return tokenize(array)
+
+
 def _compute_hashes(arrays: Iterable[np.ndarray | da.Array]) -> dict[str, _ArrayHash]:
     """Compute hashes for the arrays that will be compared."""
-    hashes = {}

     def is_numerical(dtype):
         return np.issubdtype(dtype, np.bool_) or np.issubdtype(dtype, np.number)
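
Editor's note: deriving `_ArrayHash` from a namedtuple gives it tuple semantics (immutability, repr, `__slots__`-compatible storage) for free, while the overridden `__eq__` keeps the chunk-compatibility guard and now reports the mismatching chunks. A small usage sketch with made-up values:

    import numpy as np
    from iris._concatenate import _ArrayHash

    h1 = _ArrayHash(value=np.int64(123), chunks=((2, 2),))
    h2 = _ArrayHash(value=np.int64(123), chunks=((2, 2),))
    h3 = _ArrayHash(value=np.int64(456), chunks=((4,),))

    assert h1 == h2          # same chunks, same hash value
    assert not (h1 == "x")   # NotImplemented -> ordinary equality fallback
    try:
        h1 == h3
    except ValueError:
        pass  # hashes computed with different chunks are not comparable

`array_id` supplies the dictionary key: for masked arrays it tokenizes the data and mask separately, which the inline comment notes is much faster than tokenizing the masked array wholesale.
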
@@ -350,9 +368,11 @@ def group_key(a):
             dtype = str(a.dtype)
         return a.shape, dtype

+    hashes = {}
+
     arrays = sorted(arrays, key=group_key)
-    for _, group in itertools.groupby(arrays, key=group_key):
-        group = list(group)
+    for _, group_iter in itertools.groupby(arrays, key=group_key):
+        group = list(group_iter)
         # Unify dtype for numerical arrays, as the hash depends on it
         if is_numerical(group[0].dtype):
             dtype = np.result_type(*group)
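
Editor's note: the `sorted` call above matters because `itertools.groupby` only merges *consecutive* items with equal keys; without sorting, arrays sharing a shape/dtype group key could land in separate groups. A standalone illustration:

    import itertools

    items = ["a", "b", "a"]
    [(k, list(g)) for k, g in itertools.groupby(items)]
    # -> [('a', ['a']), ('b', ['b']), ('a', ['a'])]  ('a' split in two)

    [(k, list(g)) for k, g in itertools.groupby(sorted(items))]
    # -> [('a', ['a', 'a']), ('b', ['b'])]

Renaming the loop variable to `group_iter` also avoids rebinding `group` from an iterator to a list, which is clearer and friendlier to type checkers.
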
@@ -364,23 +384,23 @@ def group_key(a):
             argpairs = [(a, indices) for a in same_dtype_arrays]
             rechunked_arrays = da.core.unify_chunks(*itertools.chain(*argpairs))[1]
         for array, rechunked in zip(group, rechunked_arrays):
-            hashes[dask.base.tokenize(array)] = (
+            hashes[array_id(array)] = (
                 _hash_array(rechunked),
                 rechunked.chunks,
             )

-    result = dask.compute(hashes)[0]
-    return {k: _ArrayHash(*v) for k, v in result.items()}
+    hashes = dask.compute(hashes)[0]
+    return {k: _ArrayHash(*v) for k, v in hashes.items()}


 def concatenate(
-    cubes,
-    error_on_mismatch=False,
-    check_aux_coords=True,
-    check_cell_measures=True,
-    check_ancils=True,
-    check_derived_coords=True,
-):
+    cubes: Sequence[iris.cube.Cube],
+    error_on_mismatch: bool = False,
+    check_aux_coords: bool = True,
+    check_cell_measures: bool = True,
+    check_ancils: bool = True,
+    check_derived_coords: bool = True,
+) -> iris.cube.CubeList:
     """Concatenate the provided cubes over common existing dimensions.

     Parameters
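
Editor's note: two dask idioms carry the hunk above. `da.core.unify_chunks` rechunks a group of arrays onto a shared chunking so their per-block hashes line up, and a single `dask.compute` call on the whole dict evaluates every lazy hash in one graph traversal. A standalone sketch (shapes and chunk sizes are arbitrary):

    import dask
    import dask.array as da

    a = da.arange(8, chunks=4)
    b = da.arange(8, chunks=2)

    # Rechunk both arrays onto a common chunking along dimension "i".
    _, (a2, b2) = da.core.unify_chunks(a, "i", b, "i")
    assert a2.chunks == b2.chunks == ((2, 2, 2, 2),)

    # dask.compute traverses containers: one call evaluates all the values.
    (computed,) = dask.compute({"a": a2.sum(), "b": b2.sum()})

Reusing the `hashes` name for the computed result (instead of a separate `result` binding) keeps one fewer name in scope without changing behavior.
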
Expand Down Expand Up @@ -418,7 +438,7 @@ def concatenate(
A :class:`iris.cube.CubeList` of concatenated :class:`iris.cube.Cube` instances.
"""
proto_cubes = []
proto_cubes: list[_ProtoCube] = []
# Initialise the nominated axis (dimension) of concatenation
# which requires to be negotiated.
axis = None
Expand Down Expand Up @@ -718,7 +738,7 @@ def match(self, other, error_on_mismatch):
class _CoordSignature:
"""Template for identifying a specific type of :class:`iris.cube.Cube` based on its coordinates."""

def __init__(self, cube_signature):
def __init__(self, cube_signature: _CubeSignature) -> None:
"""Represent the coordinate metadata.
Represent the coordinate metadata required to identify suitable
@@ -737,7 +757,7 @@ def __init__(self, cube_signature):
         self.derived_coords_and_dims = cube_signature.derived_coords_and_dims
         self.dim_coords = cube_signature.dim_coords
         self.dim_mapping = cube_signature.dim_mapping
-        self.dim_extents = []
+        self.dim_extents: list[_CoordExtent] = []
         self.dim_order = [
             metadata.kwargs["order"] for metadata in cube_signature.dim_metadata
         ]
@@ -746,7 +766,10 @@ def __init__(self, cube_signature):
         self._calculate_extents()

     @staticmethod
-    def _cmp(coord, other):
+    def _cmp(
+        coord: iris.coords.DimCoord,
+        other: iris.coords.DimCoord,
+    ) -> tuple[bool, bool]:
         """Compare the coordinates for concatenation compatibility.

         Returns
@@ -757,22 +780,17 @@ def _cmp(coord, other):
         """
         # A candidate axis must have non-identical coordinate points.
-        candidate_axis = not array_equal(coord.core_points(), other.core_points())
+        candidate_axis = not array_equal(coord.points, other.points)

-        if candidate_axis:
-            # Ensure both have equal availability of bounds.
-            result = (coord.core_bounds() is None) == (other.core_bounds() is None)
-        else:
-            if coord.core_bounds() is not None and other.core_bounds() is not None:
-                # Ensure equality of bounds.
-                result = array_equal(coord.core_bounds(), other.core_bounds())
-            else:
-                # Ensure both have equal availability of bounds.
-                result = coord.core_bounds() is None and other.core_bounds() is None
+        # Ensure both have equal availability of bounds.
+        result = coord.has_bounds() == other.has_bounds()
+        if result and not candidate_axis:
+            # Ensure equality of bounds.
+            result = array_equal(coord.bounds, other.bounds)

         return result, candidate_axis

-    def candidate_axis(self, other):
+    def candidate_axis(self, other: "_CoordSignature") -> int | None:
         """Determine the candidate axis of concatenation with the given coordinate signature.

         If a candidate axis is found, then the coordinate
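
Editor's note: the rewritten `_cmp` collapses the nested `core_bounds() is None` checks into a single `has_bounds()` parity test, and only compares bounds values when the pair is not already a concatenation candidate. Roughly, with iris coordinates (values made up):

    import numpy as np
    from iris.coords import DimCoord

    bounded = DimCoord(
        np.array([0.0, 1.0]), bounds=np.array([[-0.5, 0.5], [0.5, 1.5]])
    )
    unbounded = DimCoord(np.array([0.0, 1.0]))

    # Bounds availability must agree before bounds values are compared.
    assert bounded.has_bounds() != unbounded.has_bounds()

Note also the switch from `core_points()`/`core_bounds()` to the realized `points`/`bounds` properties here.
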
Expand Down Expand Up @@ -804,13 +822,13 @@ def candidate_axis(self, other):
# Only permit one degree of dimensional freedom when
# determining the candidate axis of concatenation.
if result and len(candidate_axes) == 1:
result = candidate_axes[0]
axis = candidate_axes[0]
else:
result = None
axis = None

return result
return axis

def _calculate_extents(self):
def _calculate_extents(self) -> None:
"""Calculate the extent over each dimension coordinates points and bounds."""
self.dim_extents = []
for coord, order in zip(self.dim_coords, self.dim_order):
Expand Down Expand Up @@ -950,15 +968,15 @@ def concatenate(self):

def register(
self,
cube,
hashes,
axis=None,
error_on_mismatch=False,
check_aux_coords=False,
check_cell_measures=False,
check_ancils=False,
check_derived_coords=False,
):
cube: iris.cube.Cube,
hashes: dict[str, _ArrayHash],
axis: int | None = None,
error_on_mismatch: bool = False,
check_aux_coords: bool = False,
check_cell_measures: bool = False,
check_ancils: bool = False,
check_derived_coords: bool = False,
) -> bool:
"""Determine if the given source-cube is suitable for concatenation.
Determine if the given source-cube is suitable for concatenation
Expand Down Expand Up @@ -1032,7 +1050,7 @@ def register(
warnings.warn(msg, category=iris.warnings.IrisUserWarning)

def get_hash(array):
return hashes[dask.base.tokenize(array)]
return hashes[array_id(array)]

def get_hashes(coord):
result = []
3 changes: 1 addition & 2 deletions lib/iris/tests/unit/concatenate/test_hashing.py
@@ -5,7 +5,6 @@
 """Test array hashing in :mod:`iris._concatenate`."""

 import dask.array as da
-from dask.base import tokenize
 import numpy as np
 import pytest

@@ -30,7 +29,7 @@
 )
 def test_compute_hashes(a, b, eq):
     hashes = _concatenate._compute_hashes([a, b])
-    assert eq == (hashes[tokenize(a)] == hashes[tokenize(b)])
+    assert eq == (hashes[_concatenate.array_id(a)] == hashes[_concatenate.array_id(b)])


 def test_arrayhash_incompatible_chunks_raises():
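
Editor's note: the test now looks hashes up via `_concatenate.array_id` rather than raw `tokenize`, matching the keying used inside `_compute_hashes`. A minimal standalone round trip in the same spirit (arrays chosen arbitrarily; integer and float arrays share one "numerical" group, so they are cast to a common dtype before hashing):

    import dask.array as da
    from iris import _concatenate

    a = da.arange(6, chunks=3)
    b = da.arange(6.0, chunks=2)

    hashes = _concatenate._compute_hashes([a, b])
    # Equal values hash equal once dtype and chunks are unified.
    assert hashes[_concatenate.array_id(a)] == hashes[_concatenate.array_id(b)]
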
