diff --git a/benchmarks/find_indices.py b/benchmarks/find_indices.py index 3176b5c..0c896be 100644 --- a/benchmarks/find_indices.py +++ b/benchmarks/find_indices.py @@ -28,6 +28,6 @@ def time_find_indices_1d(num: int, method: str): time_find_indices_1d.setup = setup_find_indices_1d time_find_indices_1d.params = ( list(np.linspace(0, 1e4, num=11, dtype=int)[1:]), - ["brute"], + ["brute", "searchsorted"], ) time_find_indices_1d.param_names = ["num", "method"] diff --git a/regridding/_find_indices/_find_indices.py b/regridding/_find_indices/_find_indices.py index 41c736d..2d26d2e 100644 --- a/regridding/_find_indices/_find_indices.py +++ b/regridding/_find_indices/_find_indices.py @@ -2,6 +2,7 @@ import numpy as np from regridding import _util from ._find_indices_brute import _find_indices_brute +from ._find_indices_searchsorted import _find_indices_searchsorted __all__ = [ "find_indices", @@ -14,7 +15,7 @@ def find_indices( axis_input: None | int | tuple[int, ...] = None, axis_output: None | int | tuple[int, ...] = None, fill_value: None | int = None, - method: Literal["brute"] = "brute", + method: Literal["brute"] | Literal["searchsorted"] = "brute", ): """ Find the index of the input cell which contains the output vertex. @@ -74,6 +75,12 @@ def find_indices( vertices_output=vertices_output, fill_value=fill_value, ) + elif method == "searchsorted": + indices_output = _find_indices_searchsorted( + vertices_input=vertices_input, + vertices_output=vertices_output, + fill_value=fill_value, + ) else: raise ValueError(f"method `{method}` not recognized.") diff --git a/regridding/_find_indices/_find_indices_searchsorted.py b/regridding/_find_indices/_find_indices_searchsorted.py new file mode 100644 index 0000000..82ab141 --- /dev/null +++ b/regridding/_find_indices/_find_indices_searchsorted.py @@ -0,0 +1,66 @@ +import numpy as np +import numba + + +def _find_indices_searchsorted( + vertices_input: tuple[np.ndarray, ...], + vertices_output: tuple[np.ndarray, ...], + fill_value: None | int = None, +) -> tuple[np.ndarray, ...]: + if len(vertices_input) == 1: + result = _find_indices_searchsorted_1d( + vertices_input=vertices_input, + vertices_output=vertices_output, + fill_value=fill_value, + ) + else: + raise ValueError( + f"{len(vertices_input)}-dimensional searchsorted not supported" + ) + + return result + + +@numba.njit(inline="always") +def _find_indices_searchsorted_1d( + vertices_input: tuple[np.ndarray], + vertices_output: tuple[np.ndarray], + fill_value: int, +) -> tuple[np.ndarray]: + (x_input,) = vertices_input + (x_output,) = vertices_output + + num_d, num_m = x_input.shape + num_d, num_i = x_output.shape + + result = np.empty(shape=x_output.shape, dtype=np.int32) + + for d in numba.prange(num_d): + x_input_d = x_input[d] + x_output_d = x_output[d] + + result_d = np.searchsorted( + a=x_input_d, + v=x_output_d, + ) + + x_input_d_min = x_input_d[0] + for i in range(num_i): + x_output_di = x_output_d[i] + result_di = result_d[i] + result_di = result_di - 1 + if x_output_di == x_input_d_min: + result_di = 0 + elif result_di < 0: + result_di = fill_value + elif result_di > num_m: + result_di = fill_value + result_d[i] = result_di + + + # result_d = result_d - 1 + # result_d[result_d < 0] = fill_value + # result_d[result_d >= num_m] = fill_value + result[d] = result_d + + return (result,) diff --git a/regridding/_find_indices/_tests/test_find_indices.py b/regridding/_find_indices/_tests/test_find_indices.py index c5d27ea..43a0192 100644 --- a/regridding/_find_indices/_tests/test_find_indices.py +++ b/regridding/_find_indices/_tests/test_find_indices.py @@ -13,7 +13,7 @@ ), ], ) -@pytest.mark.parametrize("method", ["brute"]) +@pytest.mark.parametrize("method", ["brute", "searchsorted"]) def test_find_indices_1d( vertices_input: tuple[np.ndarray], vertices_output: tuple[np.ndarray],