Skip to content

Commit

Permalink
Added a "searchsorted" method to regridding.find_indices().
Browse files Browse the repository at this point in the history
  • Loading branch information
byrdie committed Nov 30, 2023
1 parent 3057809 commit d410a25
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 3 deletions.
2 changes: 1 addition & 1 deletion benchmarks/find_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ def time_find_indices_1d(num: int, method: str):
time_find_indices_1d.setup = setup_find_indices_1d
time_find_indices_1d.params = (
list(np.linspace(0, 1e4, num=11, dtype=int)[1:]),
["brute"],
["brute", "searchsorted"],
)
time_find_indices_1d.param_names = ["num", "method"]
9 changes: 8 additions & 1 deletion regridding/_find_indices/_find_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from regridding import _util
from ._find_indices_brute import _find_indices_brute
from ._find_indices_searchsorted import _find_indices_searchsorted

__all__ = [
"find_indices",
Expand All @@ -14,7 +15,7 @@ def find_indices(
axis_input: None | int | tuple[int, ...] = None,
axis_output: None | int | tuple[int, ...] = None,
fill_value: None | int = None,
method: Literal["brute"] = "brute",
method: Literal["brute"] | Literal["searchsorted"] = "brute",
):
"""
Find the index of the input cell which contains the output vertex.
Expand Down Expand Up @@ -74,6 +75,12 @@ def find_indices(
vertices_output=vertices_output,
fill_value=fill_value,
)
elif method == "searchsorted":
indices_output = _find_indices_searchsorted(
vertices_input=vertices_input,
vertices_output=vertices_output,
fill_value=fill_value,
)
else:
raise ValueError(f"method `{method}` not recognized.")

Expand Down
66 changes: 66 additions & 0 deletions regridding/_find_indices/_find_indices_searchsorted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
import numba


def _find_indices_searchsorted(
vertices_input: tuple[np.ndarray, ...],
vertices_output: tuple[np.ndarray, ...],
fill_value: None | int = None,
) -> tuple[np.ndarray, ...]:
if len(vertices_input) == 1:
result = _find_indices_searchsorted_1d(
vertices_input=vertices_input,
vertices_output=vertices_output,
fill_value=fill_value,
)
else:
raise ValueError(
f"{len(vertices_input)}-dimensional searchsorted not supported"
)

return result


@numba.njit(inline="always")
def _find_indices_searchsorted_1d(
vertices_input: tuple[np.ndarray],
vertices_output: tuple[np.ndarray],
fill_value: int,
) -> tuple[np.ndarray]:
(x_input,) = vertices_input
(x_output,) = vertices_output

num_d, num_m = x_input.shape
num_d, num_i = x_output.shape

result = np.empty(shape=x_output.shape, dtype=np.int32)

for d in numba.prange(num_d):
x_input_d = x_input[d]
x_output_d = x_output[d]

result_d = np.searchsorted(
a=x_input_d,
v=x_output_d,
)

x_input_d_min = x_input_d[0]
for i in range(num_i):
x_output_di = x_output_d[i]
result_di = result_d[i]
result_di = result_di - 1
if x_output_di == x_input_d_min:
result_di = 0
elif result_di < 0:
result_di = fill_value
elif result_di > num_m:
result_di = fill_value
result_d[i] = result_di


# result_d = result_d - 1
# result_d[result_d < 0] = fill_value
# result_d[result_d >= num_m] = fill_value
result[d] = result_d

return (result,)
2 changes: 1 addition & 1 deletion regridding/_find_indices/_tests/test_find_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
),
],
)
@pytest.mark.parametrize("method", ["brute"])
@pytest.mark.parametrize("method", ["brute", "searchsorted"])
def test_find_indices_1d(
vertices_input: tuple[np.ndarray],
vertices_output: tuple[np.ndarray],
Expand Down

0 comments on commit d410a25

Please sign in to comment.