-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added
regridding.fill()
function to fill in missing data using inte…
…rpolation (#3) Added `regridding.fill()` function to fill in missing data using interpolation.
- Loading branch information
Showing
5 changed files
with
260 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
from ._weights import * | ||
from ._interp_ndarray import * | ||
from ._regrid import * | ||
from ._fill import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ._fill import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from typing import Sequence, Literal | ||
import numpy as np | ||
from ._gauss_seidel import fill_gauss_seidel | ||
|
||
__all__ = [ | ||
"fill", | ||
] | ||
|
||
|
||
def fill( | ||
a: np.ndarray, | ||
where: None | np.ndarray = None, | ||
axis: None | int | Sequence[int] = None, | ||
method: Literal["gauss_seidel"] = "gauss_seidel", | ||
**kwargs, | ||
) -> np.ndarray: | ||
""" | ||
Fill an array with missing values by interpolating from the valid points. | ||
Parameters | ||
---------- | ||
a | ||
The array with missing values to be filled | ||
where | ||
Boolean array of missing values. | ||
If :obj:`None` (the default), all NaN values will be filled. | ||
axis | ||
The axes to use for interpolation. | ||
If :obj:`None` (the default), interpolate along all the axes of `a`. | ||
method | ||
The interpolation method to use. | ||
The only option is "gauss_seidel", which uses the Gauss-Seidel relaxation | ||
technique to interpolate the valid data points. | ||
kwargs | ||
Additional method-specific keyword arguments. | ||
For the Gauss-Seidel method, the valid keyword arguments are: | ||
- ``num_iterations=100``, the number of red-black Gauss-Seidel iterations to perform. | ||
Examples | ||
-------- | ||
Set random elements of an array to NaN, and then fill in the missing elements | ||
using the Gauss-Seidel relaxation method. | ||
.. jupyter-execute:: | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import regridding | ||
# Define the independent variables | ||
x = 3 * np.pi * np.linspace(-1, 1, num=51) | ||
y = 3 * np.pi * np.linspace(-1, 1, num=51) | ||
x, y = np.meshgrid(x, y, indexing="ij") | ||
# Define the array to remove elements from | ||
a = np.cos(x) * np.cos(y) | ||
# Define the elements of the array to remove | ||
where = np.random.uniform(0, 1, size=a.shape) > 0.9 | ||
# Set random elements of the array to NaN | ||
a_missing = a.copy() | ||
a_missing[where] = np.nan | ||
# Fill the missing elements using Gauss-Seidel relaxation | ||
b = regridding.fill(a_missing, method="gauss_seidel", num_iterations=11) | ||
# Plot the results | ||
fig, axs = plt.subplots( | ||
ncols=3, | ||
figsize=(6, 3), | ||
sharey=True, | ||
constrained_layout=True, | ||
) | ||
kwargs_imshow = dict( | ||
vmin=a.min(), | ||
vmax=a.max(), | ||
) | ||
axs[0].imshow(a_missing, **kwargs_imshow); | ||
axs[1].imshow(b, **kwargs_imshow); | ||
axs[2].imshow(a - b, **kwargs_imshow); | ||
axs[0].set_title("original array"); | ||
axs[1].set_title("filled array"); | ||
axs[2].set_title("difference"); | ||
""" | ||
|
||
if where is None: | ||
where = np.isnan(a) | ||
|
||
if method == "gauss_seidel": | ||
return fill_gauss_seidel( | ||
a=a, | ||
where=where, | ||
axis=axis, | ||
**kwargs, | ||
) | ||
else: # pragma: nocover | ||
raise ValueError("Unrecognized method '{method}'") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import pytest | ||
import numpy as np | ||
import regridding | ||
|
||
_num_x = 11 | ||
_num_y = 12 | ||
_num_t = 13 | ||
|
||
|
||
@pytest.mark.parametrize( | ||
argnames="a,where,axis", | ||
argvalues=[ | ||
( | ||
np.random.uniform(0, 1, size=(_num_x, _num_y)), | ||
np.random.uniform(0, 1, size=(_num_x, _num_y)) > 0.9, | ||
None, | ||
), | ||
( | ||
np.random.uniform(0, 1, size=(_num_t, _num_x, _num_y)), | ||
np.random.uniform(0, 1, size=(_num_t, _num_x, _num_y)) > 0.9, | ||
(~1, ~0), | ||
), | ||
( | ||
np.sqrt(np.random.uniform(-0.1, 1, size=(_num_x, _num_t, _num_y))), | ||
None, | ||
(0, ~0), | ||
), | ||
], | ||
) | ||
@pytest.mark.parametrize("num_iterations", [11]) | ||
def test_fill_gauss_sidel_2d( | ||
a: np.ndarray, | ||
where: np.ndarray, | ||
axis: None | tuple[int, ...], | ||
num_iterations: int, | ||
): | ||
result = regridding.fill( | ||
a=a, | ||
where=where, | ||
axis=axis, | ||
method="gauss_seidel", | ||
num_iterations=num_iterations, | ||
) | ||
if where is None: | ||
where = np.isnan(a) | ||
|
||
assert np.allclose(result[~where], a[~where]) | ||
assert np.all(result[where] != 0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from typing import Sequence | ||
import numpy as np | ||
import numba | ||
import regridding._util | ||
|
||
__all__ = [ | ||
"fill_gauss_seidel", | ||
] | ||
|
||
|
||
def fill_gauss_seidel( | ||
a: np.ndarray, | ||
where: np.ndarray, | ||
axis: None | int | Sequence[int], | ||
num_iterations: int = 100, | ||
) -> np.ndarray: | ||
|
||
a = a.copy() | ||
|
||
a, where = np.broadcast_arrays(a, where, subok=True) | ||
|
||
a[where] = 0 | ||
|
||
axis = regridding._util._normalize_axis(axis=axis, ndim=a.ndim) | ||
axis_numba = ~np.arange(len(axis))[::-1] | ||
|
||
shape = a.shape | ||
shape_numba = tuple(shape[ax] for ax in axis) | ||
|
||
a = np.moveaxis(a, axis, axis_numba) | ||
where = np.moveaxis(where, axis, axis_numba) | ||
|
||
shape_moved = a.shape | ||
|
||
a = a.reshape(-1, *shape_numba) | ||
where = where.reshape(-1, *shape_numba) | ||
|
||
if len(axis) == 2: | ||
result = _fill_gauss_seidel_2d( | ||
a=a, | ||
where=where, | ||
num_iterations=num_iterations, | ||
) | ||
else: # pragma: nocover | ||
raise ValueError( | ||
f"The number of interpolation axes, {len(axis)}," f"is not supported" | ||
) | ||
|
||
result = result.reshape(shape_moved) | ||
result = np.moveaxis(result, axis_numba, axis) | ||
|
||
return result | ||
|
||
|
||
@numba.njit(parallel=True) | ||
def _fill_gauss_seidel_2d( | ||
a: np.ndarray, | ||
where: np.ndarray, | ||
num_iterations: int, | ||
) -> np.ndarray: | ||
|
||
num_t, num_y, num_x = a.shape | ||
|
||
for t in numba.prange(num_t): | ||
for k in range(num_iterations): | ||
for is_odd in [False, True]: | ||
_iteration_gauss_seidel_2d( | ||
a=a, | ||
where=where, | ||
t=t, | ||
num_x=num_x, | ||
num_y=num_y, | ||
is_odd=is_odd, | ||
) | ||
|
||
return a | ||
|
||
|
||
@numba.njit(fastmath=True) | ||
def _iteration_gauss_seidel_2d( | ||
a: np.ndarray, | ||
where: np.ndarray, | ||
t: int, | ||
num_x: int, | ||
num_y: int, | ||
is_odd: bool, | ||
) -> None: | ||
|
||
xmin, xmax = -1, 1 | ||
ymin, ymax = -1, 1 | ||
|
||
dx = (xmax - xmin) / (num_x - 1) | ||
dy = (ymax - ymin) / (num_y - 1) | ||
|
||
dxxinv = 1 / (dx * dx) | ||
dyyinv = 1 / (dy * dy) | ||
|
||
dcent = 1 / (2 * (dxxinv + dyyinv)) | ||
|
||
for j in range(num_y): | ||
for i in range(num_x): | ||
if (i + j) & 1 == is_odd: | ||
if where[t, j, i]: | ||
i9 = (i - 1) % num_x | ||
i1 = (i + 1) % num_x | ||
j9 = (j - 1) % num_y | ||
j1 = (j + 1) % num_y | ||
|
||
xterm = dxxinv * (a[t, j, i9] + a[t, j, i1]) | ||
yterm = dyyinv * (a[t, j9, i] + a[t, j1, i]) | ||
a[t, j, i] = (xterm + yterm) * dcent |