Skip to content

Commit

Permalink
Added regridding.fill() function to fill in missing data using inte…
Browse files Browse the repository at this point in the history
…rpolation (#3)

Added `regridding.fill()` function to fill in missing data using interpolation.
  • Loading branch information
byrdie committed Mar 6, 2024
1 parent 9135aee commit 1a58915
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 0 deletions.
1 change: 1 addition & 0 deletions regridding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from ._weights import *
from ._interp_ndarray import *
from ._regrid import *
from ._fill import *
1 change: 1 addition & 0 deletions regridding/_fill/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._fill import *
99 changes: 99 additions & 0 deletions regridding/_fill/_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Sequence, Literal
import numpy as np
from ._gauss_seidel import fill_gauss_seidel

__all__ = [
"fill",
]


def fill(
a: np.ndarray,
where: None | np.ndarray = None,
axis: None | int | Sequence[int] = None,
method: Literal["gauss_seidel"] = "gauss_seidel",
**kwargs,
) -> np.ndarray:
"""
Fill an array with missing values by interpolating from the valid points.
Parameters
----------
a
The array with missing values to be filled
where
Boolean array of missing values.
If :obj:`None` (the default), all NaN values will be filled.
axis
The axes to use for interpolation.
If :obj:`None` (the default), interpolate along all the axes of `a`.
method
The interpolation method to use.
The only option is "gauss_seidel", which uses the Gauss-Seidel relaxation
technique to interpolate the valid data points.
kwargs
Additional method-specific keyword arguments.
For the Gauss-Seidel method, the valid keyword arguments are:
- ``num_iterations=100``, the number of red-black Gauss-Seidel iterations to perform.
Examples
--------
Set random elements of an array to NaN, and then fill in the missing elements
using the Gauss-Seidel relaxation method.
.. jupyter-execute::
import numpy as np
import matplotlib.pyplot as plt
import regridding
# Define the independent variables
x = 3 * np.pi * np.linspace(-1, 1, num=51)
y = 3 * np.pi * np.linspace(-1, 1, num=51)
x, y = np.meshgrid(x, y, indexing="ij")
# Define the array to remove elements from
a = np.cos(x) * np.cos(y)
# Define the elements of the array to remove
where = np.random.uniform(0, 1, size=a.shape) > 0.9
# Set random elements of the array to NaN
a_missing = a.copy()
a_missing[where] = np.nan
# Fill the missing elements using Gauss-Seidel relaxation
b = regridding.fill(a_missing, method="gauss_seidel", num_iterations=11)
# Plot the results
fig, axs = plt.subplots(
ncols=3,
figsize=(6, 3),
sharey=True,
constrained_layout=True,
)
kwargs_imshow = dict(
vmin=a.min(),
vmax=a.max(),
)
axs[0].imshow(a_missing, **kwargs_imshow);
axs[1].imshow(b, **kwargs_imshow);
axs[2].imshow(a - b, **kwargs_imshow);
axs[0].set_title("original array");
axs[1].set_title("filled array");
axs[2].set_title("difference");
"""

if where is None:
where = np.isnan(a)

if method == "gauss_seidel":
return fill_gauss_seidel(
a=a,
where=where,
axis=axis,
**kwargs,
)
else: # pragma: nocover
raise ValueError("Unrecognized method '{method}'")
48 changes: 48 additions & 0 deletions regridding/_fill/_fill_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
import numpy as np
import regridding

_num_x = 11
_num_y = 12
_num_t = 13


@pytest.mark.parametrize(
argnames="a,where,axis",
argvalues=[
(
np.random.uniform(0, 1, size=(_num_x, _num_y)),
np.random.uniform(0, 1, size=(_num_x, _num_y)) > 0.9,
None,
),
(
np.random.uniform(0, 1, size=(_num_t, _num_x, _num_y)),
np.random.uniform(0, 1, size=(_num_t, _num_x, _num_y)) > 0.9,
(~1, ~0),
),
(
np.sqrt(np.random.uniform(-0.1, 1, size=(_num_x, _num_t, _num_y))),
None,
(0, ~0),
),
],
)
@pytest.mark.parametrize("num_iterations", [11])
def test_fill_gauss_sidel_2d(
a: np.ndarray,
where: np.ndarray,
axis: None | tuple[int, ...],
num_iterations: int,
):
result = regridding.fill(
a=a,
where=where,
axis=axis,
method="gauss_seidel",
num_iterations=num_iterations,
)
if where is None:
where = np.isnan(a)

assert np.allclose(result[~where], a[~where])
assert np.all(result[where] != 0)
111 changes: 111 additions & 0 deletions regridding/_fill/_gauss_seidel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Sequence
import numpy as np
import numba
import regridding._util

__all__ = [
"fill_gauss_seidel",
]


def fill_gauss_seidel(
a: np.ndarray,
where: np.ndarray,
axis: None | int | Sequence[int],
num_iterations: int = 100,
) -> np.ndarray:

a = a.copy()

a, where = np.broadcast_arrays(a, where, subok=True)

a[where] = 0

axis = regridding._util._normalize_axis(axis=axis, ndim=a.ndim)
axis_numba = ~np.arange(len(axis))[::-1]

shape = a.shape
shape_numba = tuple(shape[ax] for ax in axis)

a = np.moveaxis(a, axis, axis_numba)
where = np.moveaxis(where, axis, axis_numba)

shape_moved = a.shape

a = a.reshape(-1, *shape_numba)
where = where.reshape(-1, *shape_numba)

if len(axis) == 2:
result = _fill_gauss_seidel_2d(
a=a,
where=where,
num_iterations=num_iterations,
)
else: # pragma: nocover
raise ValueError(
f"The number of interpolation axes, {len(axis)}," f"is not supported"
)

result = result.reshape(shape_moved)
result = np.moveaxis(result, axis_numba, axis)

return result


@numba.njit(parallel=True)
def _fill_gauss_seidel_2d(
a: np.ndarray,
where: np.ndarray,
num_iterations: int,
) -> np.ndarray:

num_t, num_y, num_x = a.shape

for t in numba.prange(num_t):
for k in range(num_iterations):
for is_odd in [False, True]:
_iteration_gauss_seidel_2d(
a=a,
where=where,
t=t,
num_x=num_x,
num_y=num_y,
is_odd=is_odd,
)

return a


@numba.njit(fastmath=True)
def _iteration_gauss_seidel_2d(
a: np.ndarray,
where: np.ndarray,
t: int,
num_x: int,
num_y: int,
is_odd: bool,
) -> None:

xmin, xmax = -1, 1
ymin, ymax = -1, 1

dx = (xmax - xmin) / (num_x - 1)
dy = (ymax - ymin) / (num_y - 1)

dxxinv = 1 / (dx * dx)
dyyinv = 1 / (dy * dy)

dcent = 1 / (2 * (dxxinv + dyyinv))

for j in range(num_y):
for i in range(num_x):
if (i + j) & 1 == is_odd:
if where[t, j, i]:
i9 = (i - 1) % num_x
i1 = (i + 1) % num_x
j9 = (j - 1) % num_y
j1 = (j + 1) % num_y

xterm = dxxinv * (a[t, j, i9] + a[t, j, i1])
yterm = dyyinv * (a[t, j9, i] + a[t, j1, i])
a[t, j, i] = (xterm + yterm) * dcent

0 comments on commit 1a58915

Please sign in to comment.