Skip to content

Commit

Permalink
Fixed bug in regridding.regrid_from_weights() where zeros were bein…
Browse files Browse the repository at this point in the history
…g returned if the input arrays were not contiguous.
  • Loading branch information
byrdie committed Jan 10, 2024
1 parent f2553eb commit 244aefb
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions regridding/_regrid/_regrid_from_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def regrid_from_weights(
if values_output.shape != shape_output:
raise ValueError(f"")
values_output.fill(0)
values_output_copy = values_output

ndim_output = len(shape_output)
axis_output = _util._normalize_axis(axis_output, ndim=ndim_output)

Expand All @@ -83,17 +83,26 @@ def regrid_from_weights(
values_input = np.moveaxis(values_input, axis_input, axis_input_numba)
values_output = np.moveaxis(values_output, axis_output, axis_output_numba)

shape_output_tmp = values_output.shape

weights = numba.typed.List(weights.reshape(-1))
values_input = values_input.reshape(-1, *shape_input_numba)
values_output = values_output.reshape(-1, *shape_output_numba)

values_input = np.ascontiguousarray(values_input)
values_output = np.ascontiguousarray(values_output)

_regrid_from_weights(
weights=weights,
values_input=values_input,
values_output=values_output,
)

return values_output_copy
values_output = values_output.reshape(*shape_output_tmp)

values_output = np.moveaxis(values_output, axis_output_numba, axis_output)

return values_output


@numba.njit()
Expand Down

0 comments on commit 244aefb

Please sign in to comment.