diff --git a/regridding/_regrid/_regrid_from_weights.py b/regridding/_regrid/_regrid_from_weights.py index a952185..e88528b 100644 --- a/regridding/_regrid/_regrid_from_weights.py +++ b/regridding/_regrid/_regrid_from_weights.py @@ -70,7 +70,7 @@ def regrid_from_weights( if values_output.shape != shape_output: raise ValueError(f"") values_output.fill(0) - values_output_copy = values_output + ndim_output = len(shape_output) axis_output = _util._normalize_axis(axis_output, ndim=ndim_output) @@ -83,17 +83,26 @@ def regrid_from_weights( values_input = np.moveaxis(values_input, axis_input, axis_input_numba) values_output = np.moveaxis(values_output, axis_output, axis_output_numba) + shape_output_tmp = values_output.shape + weights = numba.typed.List(weights.reshape(-1)) values_input = values_input.reshape(-1, *shape_input_numba) values_output = values_output.reshape(-1, *shape_output_numba) + values_input = np.ascontiguousarray(values_input) + values_output = np.ascontiguousarray(values_output) + _regrid_from_weights( weights=weights, values_input=values_input, values_output=values_output, ) - return values_output_copy + values_output = values_output.reshape(*shape_output_tmp) + + values_output = np.moveaxis(values_output, axis_output_numba, axis_output) + + return values_output @numba.njit()