Skip to content

Commit

Permalink
Renamed vertices_input and vertices_output arguments of `regriddi…
Browse files Browse the repository at this point in the history
…ng.find_indices()` to `coordinates_input` and `coordinates_output`
  • Loading branch information
byrdie committed Nov 30, 2023
1 parent c7e8df5 commit 7ce9481
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 55 deletions.
20 changes: 10 additions & 10 deletions benchmarks/find_indices.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
import numpy as np
import regridding

vertices_input = None
vertices_output = None
coordinates_input = None
coordinates_output = None


def setup_find_indices_1d(num: int, method: str):
global vertices_input
global vertices_output
vertices_input = (np.linspace(-1, 1, num=num),)
vertices_output = (np.linspace(-1.1, 1.1, num=num),)
global coordinates_input
global coordinates_output
coordinates_input = (np.linspace(-1, 1, num=num),)
coordinates_output = (np.linspace(-1.1, 1.1, num=num),)
regridding.find_indices(
vertices_input=vertices_input,
vertices_output=vertices_output,
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
method=method,
)


def time_find_indices_1d(num: int, method: str):
regridding.find_indices(
vertices_input=vertices_input,
vertices_output=vertices_output,
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
method=method,
)

Expand Down
32 changes: 16 additions & 16 deletions regridding/_find_indices/_find_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@


def find_indices(
vertices_input: tuple[np.ndarray, ...],
vertices_output: tuple[np.ndarray, ...],
coordinates_input: tuple[np.ndarray, ...],
coordinates_output: tuple[np.ndarray, ...],
axis_input: None | int | tuple[int, ...] = None,
axis_output: None | int | tuple[int, ...] = None,
fill_value: None | int = None,
Expand All @@ -22,9 +22,9 @@ def find_indices(
Parameters
----------
vertices_input
coordinates_input
the source grid
vertices_output
coordinates_output
the destination grid
axis_input
the axes in the source grid to search
Expand All @@ -37,16 +37,16 @@ def find_indices(
"""

(
vertices_input,
vertices_output,
coordinates_input,
coordinates_output,
axis_input,
axis_output,
shape_input,
shape_output,
shape_orthogonal,
) = _util._normalize_input_output_vertices(
vertices_input=vertices_input,
vertices_output=vertices_output,
vertices_input=coordinates_input,
vertices_output=coordinates_output,
axis_input=axis_input,
axis_output=axis_output,
)
Expand All @@ -60,25 +60,25 @@ def find_indices(
shape_input_numba = tuple(shape_input[ax] for ax in axis_input)
shape_output_numba = tuple(shape_output[ax] for ax in axis_output)

vertices_input = tuple(
coordinates_input = tuple(
np.moveaxis(v, axis_input, axis_input_numba).reshape(-1, *shape_input_numba)
for v in vertices_input
for v in coordinates_input
)
vertices_output = tuple(
coordinates_output = tuple(
np.moveaxis(v, axis_output, axis_output_numba).reshape(-1, *shape_output_numba)
for v in vertices_output
for v in coordinates_output
)

if method == "brute":
indices_output = _find_indices_brute(
vertices_input=vertices_input,
vertices_output=vertices_output,
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
fill_value=fill_value,
)
elif method == "searchsorted":
indices_output = _find_indices_searchsorted(
vertices_input=vertices_input,
vertices_output=vertices_output,
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
fill_value=fill_value,
)
else:
Expand Down
20 changes: 10 additions & 10 deletions regridding/_find_indices/_find_indices_brute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,32 @@


def _find_indices_brute(
vertices_input: tuple[np.ndarray, ...],
vertices_output: tuple[np.ndarray, ...],
coordinates_input: tuple[np.ndarray, ...],
coordinates_output: tuple[np.ndarray, ...],
fill_value: None | int = None,
) -> tuple[np.ndarray, ...]:
if len(vertices_input) == 1:
if len(coordinates_input) == 1:
result = _find_indices_brute_1d(
vertices_input=vertices_input,
vertices_output=vertices_output,
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
fill_value=fill_value,
)
else:
raise ValueError(
f"{len(vertices_input)}-dimensional brute-force search not supported"
f"{len(coordinates_input)}-dimensional brute-force search not supported"
)

return result


@numba.njit(inline="always")
def _find_indices_brute_1d(
vertices_input: tuple[np.ndarray],
vertices_output: tuple[np.ndarray],
coordinates_input: tuple[np.ndarray],
coordinates_output: tuple[np.ndarray],
fill_value: int,
) -> tuple[np.ndarray]:
(x_input,) = vertices_input
(x_output,) = vertices_output
(x_input,) = coordinates_input
(x_output,) = coordinates_output

num_d, num_m = x_input.shape
num_d, num_i = x_output.shape
Expand Down
20 changes: 10 additions & 10 deletions regridding/_find_indices/_find_indices_searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,32 @@


def _find_indices_searchsorted(
vertices_input: tuple[np.ndarray, ...],
vertices_output: tuple[np.ndarray, ...],
coordinates_input: tuple[np.ndarray, ...],
coordinates_output: tuple[np.ndarray, ...],
fill_value: None | int = None,
) -> tuple[np.ndarray, ...]:
if len(vertices_input) == 1:
if len(coordinates_input) == 1:
result = _find_indices_searchsorted_1d(
vertices_input=vertices_input,
vertices_output=vertices_output,
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
fill_value=fill_value,
)
else:
raise ValueError(
f"{len(vertices_input)}-dimensional searchsorted not supported"
f"{len(coordinates_input)}-dimensional searchsorted not supported"
)

return result


@numba.njit(inline="always")
def _find_indices_searchsorted_1d(
vertices_input: tuple[np.ndarray],
vertices_output: tuple[np.ndarray],
coordinates_input: tuple[np.ndarray],
coordinates_output: tuple[np.ndarray],
fill_value: int,
) -> tuple[np.ndarray]:
(x_input,) = vertices_input
(x_output,) = vertices_output
(x_input,) = coordinates_input
(x_output,) = coordinates_output

num_d, num_m = x_input.shape
num_d, num_i = x_output.shape
Expand Down
18 changes: 9 additions & 9 deletions regridding/_find_indices/_tests/test_find_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@pytest.mark.parametrize(
argnames="vertices_input,vertices_output",
argnames="coordinates_input,coordinates_output",
argvalues=[
(
(np.linspace(-1, 1, num=32),),
Expand All @@ -15,19 +15,19 @@
)
@pytest.mark.parametrize("method", ["brute", "searchsorted"])
def test_find_indices_1d(
vertices_input: tuple[np.ndarray],
vertices_output: tuple[np.ndarray],
coordinates_input: tuple[np.ndarray],
coordinates_output: tuple[np.ndarray],
method: Literal["brute"],
):
result = regridding.find_indices(
vertices_input=vertices_input,
vertices_output=vertices_output,
coordinates_input=coordinates_input,
coordinates_output=coordinates_output,
method=method,
)

(vertices_input_x,) = vertices_input
(vertices_output_x,) = vertices_output
(coordinates_input_x,) = coordinates_input
(coordinates_output_x,) = coordinates_output
(result_x,) = result

assert np.all(vertices_input_x[result_x + 0] <= vertices_output_x)
assert np.all(vertices_input_x[result_x + 1] >= vertices_output_x)
assert np.all(coordinates_input_x[result_x + 0] <= coordinates_output_x)
assert np.all(coordinates_input_x[result_x + 1] >= coordinates_output_x)

0 comments on commit 7ce9481

Please sign in to comment.