
Multidimensional interpolate_na() #6360

Open · iuryt opened this issue Mar 15, 2022 · 10 comments

iuryt commented Mar 15, 2022

Is your feature request related to a problem?

I think that having a way to run a multidimensional interpolation for filling missing values would be awesome.

The code snippet below creates some data and shows the problem I am having now. If the data has some orientation, we can't simply interpolate each dimension separately.

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

n = 30
x = xr.DataArray(np.linspace(0,2*np.pi,n),dims=['x'])
y = xr.DataArray(np.linspace(0,2*np.pi,n),dims=['y'])
z = (np.sin(x)*xr.ones_like(y))

mask = xr.DataArray(np.random.randint(0,1+1,(n,n)).astype('bool'),dims=['x','y'])

kw = dict(add_colorbar=False)

fig,ax = plt.subplots(1,3,figsize=(11,3))
z.plot(ax=ax[0],**kw)
z.where(mask).plot(ax=ax[1],**kw)
z.where(mask).interpolate_na('x').plot(ax=ax[2],**kw)

[figure: original field, masked field, and the result of interpolate_na('x')]

I tried to use advanced interpolation for that, but it doesn't look like the best solution.

zs = z.where(mask).stack(k=['x','y'])
zs = zs.where(np.isnan(zs),drop=True)
xi,yi = zs.k.x.drop('k'),zs.k.y.drop('k')
zi = z.interp(x=xi,y=yi)

fig,ax = plt.subplots()
z.where(mask).plot(ax=ax,**kw)
ax.scatter(xi,yi,c=zi,linewidth=1,edgecolor='k')

returns

[figure: masked field with the interpolated values drawn as scattered points at the missing locations]

Describe the solution you'd like

Simply z.interpolate_na(['x','y'])

Describe alternatives you've considered

I could extract the data to numpy and interpolate using scipy.interpolate.griddata, but this is not the way xarray should work.
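For reference, a minimal sketch of that numpy/scipy route (illustrative only, reusing z and mask from the snippet above; the variable names are mine):

import numpy as np
from scipy.interpolate import griddata

masked = z.where(mask)
arr = masked.values
# Point coordinates in the same ('x', 'y') dimension order as the array.
xx, yy = np.meshgrid(masked.x, masked.y, indexing="ij")
valid = np.isfinite(arr).ravel()
filled = griddata(
    points=(xx.ravel()[valid], yy.ravel()[valid]),
    values=arr.ravel()[valid],
    xi=(xx, yy),
    method="linear",
)
filled_da = masked.copy(data=filled)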

Additional context

No response


@thomas-fred

I'd also find this very useful

@TheJeran

Bumping

@martin-wegmann

This would be super useful!

@albertotb

+1. As an alternative I think interpolate_na from rioxarray supports this: https://corteva.github.io/rioxarray/html/examples/interpolate_na.html
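For completeness, a rough sketch of what the rioxarray route might look like on the example above (hedged: this assumes rioxarray is installed and accepts plain "x"/"y" dims once they are declared as the spatial dims; see the linked example for the actual API):

import rioxarray  # noqa: F401 -- registers the .rio accessor

masked = z.where(mask)
# set_spatial_dims tells rioxarray which dims are spatial; interpolate_na then
# fills the gaps with scipy.interpolate.griddata under the hood.
filled = (
    masked.rio.set_spatial_dims(x_dim="x", y_dim="y")
    .rio.interpolate_na(method="linear")
)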

keewis (Collaborator) commented Aug 8, 2024

you might be interested in pyinterp. With some extreme tuning, this can even reconstruct the original image (set nx=1, ny=9):

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pyinterp
import pyinterp.fill

n = 30
x = xr.DataArray(np.linspace(0, 2 * np.pi, n), dims=["x"])
y = xr.DataArray(np.linspace(0, 2 * np.pi, n), dims=["y"])
z = np.sin(x) * xr.ones_like(y)

mask = xr.DataArray(np.random.randint(0, 1 + 1, (n, n)).astype("bool"), dims=["x", "y"])
kw = dict(add_colorbar=False)

def interpolate_na(arr):
    x = pyinterp.Axis(arr.x.data)
    y = pyinterp.Axis(arr.y.data)

    z = arr.data
    grid = pyinterp.Grid2D(x, y, z)
    filled = pyinterp.fill.loess(grid, nx=3, ny=3)
    return arr.copy(data=filled)

fig, ax = plt.subplots(1, 4, figsize=(11, 4))
z.plot(ax=ax[0], **kw)
z.where(mask).plot(ax=ax[1], **kw)
z.where(mask).interpolate_na("x").plot(ax=ax[2], **kw)
z.where(mask).pipe(interpolate_na).plot(ax=ax[3], **kw)

It does have an xarray backend, but it looks like that does not allow customizing the coordinate names; it insists on "latitude" and "longitude".
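An untested sketch of the rename workaround for that backend (the pyinterp.backends.xarray.Grid2D class name is taken from the pyinterp docs; whether pyinterp.fill.loess accepts it directly is an assumption on my part):

import pyinterp.backends.xarray

# Rename the dims/coords to the names the backend insists on.
renamed = z.where(mask).rename({"x": "longitude", "y": "latitude"})
grid = pyinterp.backends.xarray.Grid2D(renamed)
filled = pyinterp.fill.loess(grid, nx=3, ny=3)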

Huite (Contributor) commented Oct 3, 2024

scipy.ndimage.distance_transform_edt is somewhat useful for a nearest implementation:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html

If the sampling argument isn't provided, it'll just look at rows/columns/etc., i.e. equivalent to use_coordinates=False.

from scipy.ndimage import distance_transform_edt


def _nearest(a):
    nans = np.isnan(a)
    if not nans.any():
        return a.copy()
    indices = distance_transform_edt(
        input=np.isnan(a),
        return_distances=False,
        return_indices=True,
    )
    return a[tuple(indices)]


def interpolate_na(da, dim, keep_attrs=True):
    arr = xr.apply_ufunc(
        _nearest,
        da,
        input_core_dims=[dim],
        output_core_dims=[dim],
        output_dtypes=[da.dtype],
        dask="parallelized",
        vectorize=True,
        keep_attrs=keep_attrs,
    ).transpose(*da.dims)
    return arr

fig,ax = plt.subplots(1,4,figsize=(14,3))
z.plot(ax=ax[0],**kw)
z.where(mask).plot(ax=ax[1],**kw)
z.where(mask).interpolate_na('x').plot(ax=ax[2],**kw)
interpolate_na(z.where(mask), ["x", "y"]).plot(ax=ax[3],**kw)

[figure: original field, masked field, interpolate_na('x'), and the nearest-neighbour fill over ['x', 'y']]

Interpolating ["x", "y"] versus ["y", "x"] will give different answers; when the nearest neighbour is equally far away along more than one dimension (one cell removed), scipy.ndimage.distance_transform_edt will choose the one along the last dimension.

The sampling argument unfortunately only accepts a sequence of floats (one per dimension), so it only works for axis-aligned, equidistant coordinates.
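For equidistant coordinates you could still emulate use_coordinates=True by passing the grid spacing, e.g. (a sketch; dx/dy are the assumed-constant spacings of the example grid):

def _nearest_with_sampling(a, sampling):
    # Same as _nearest above, but distances are measured in coordinate units.
    nans = np.isnan(a)
    if not nans.any():
        return a.copy()
    indices = distance_transform_edt(
        input=nans,
        sampling=sampling,
        return_distances=False,
        return_indices=True,
    )
    return a[tuple(indices)]

dx = float(z.x[1] - z.x[0])
dy = float(z.y[1] - z.y[0])
masked = z.where(mask)
filled = masked.copy(data=_nearest_with_sampling(masked.data, sampling=(dx, dy)))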

@Huite
Copy link
Contributor

Huite commented Oct 3, 2024

Rioxarray's use of griddata can be made a little easier with apply_ufunc:

from scipy.interpolate import griddata


def _griddata(arr, xi, method: str):
    ar1d = arr.ravel()
    valid = np.isfinite(ar1d)
    if valid.all():
        return arr
    return griddata(
        points=tuple(x[valid] for x in xi),
        values=ar1d[valid],
        xi=xi,
        method=method,
        fill_value=np.nan,
    ).reshape(arr.shape)


def interpolate_na(da, dim, method="nearest", use_coordinates=True, keep_attrs=True):
    # Create points only once.
    if use_coordinates:
        coords = [da.coords[d] for d in dim]
    else:
        coords = [np.arange(da.sizes[d]) for d in dim]

    xi = tuple(x.ravel() for x in np.meshgrid(*coords, indexing="ij"))
    arr = xr.apply_ufunc(
        _griddata,
        da,
        input_core_dims=[dim],
        output_core_dims=[dim],
        output_dtypes=[da.dtype],
        dask="parallelized",
        vectorize=True,
        keep_attrs=keep_attrs,
        kwargs={"xi": xi, "method": method},
    ).transpose(*da.dims)
    return arr


fig,ax = plt.subplots(1,3,figsize=(11,3))
interpolate_na(z.where(mask), ["y", "x"], method="nearest").plot(ax=ax[0], **kw)
interpolate_na(z.where(mask), ["y", "x"], method="linear").plot(ax=ax[1], **kw)
interpolate_na(z.where(mask), ["y", "x"], method="cubic").plot(ax=ax[2], **kw)

[figure: griddata-based fill with nearest, linear, and cubic methods]

griddata would work for non-1D coordinates as well, with a little extra logic.
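A sketch of what that extra logic might look like (an assumption on my part, not part of the snippet above): build the points from the coordinate arrays themselves whenever they are not 1D.

def _points_from_coords(da, dim, use_coordinates=True):
    # Returns a tuple of 1D arrays, one per interpolation dimension.
    if not use_coordinates:
        coords = [np.arange(da.sizes[d]) for d in dim]
    else:
        coords = [da.coords[d] for d in dim]
    if all(np.ndim(c) == 1 for c in coords):
        mesh = np.meshgrid(*coords, indexing="ij")
        return tuple(m.ravel() for m in mesh)
    # nD (e.g. curvilinear) coordinates: broadcast them against each other so
    # every cell gets its full coordinate tuple, then flatten in core-dim order.
    broadcasted = xr.broadcast(*coords)
    return tuple(b.transpose(*dim).data.ravel() for b in broadcasted)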

Huite (Contributor) commented Oct 4, 2024

As a final note: I'm personally quite fond of "Laplace interpolation" (see e.g. chapter 3.8 of Numerical Recipes for the idea):

from scipy import sparse


def _build_connectivity(shape):
    # Get the Cartesian neighbors for a finite difference approximation.
    # TODO: check order of dimensions with DataArray
    size = np.prod(shape)
    index = np.arange(size).reshape(shape)

    # Build nD connectivity
    ii = []
    jj = []
    for d in range(len(shape)):
        slices = [slice(None)] * len(shape)

        slices[d] = slice(None, -1)
        left = index[tuple(slices)].ravel()
        slices[d] = slice(1, None)
        right = index[tuple(slices)].ravel()
        ii.extend([left, right])
        jj.extend([right, left])

    i = np.concatenate(ii)
    j = np.concatenate(jj)
    return sparse.coo_matrix(
        (np.ones(len(i)), (i, j)),
        shape=(size, size)
    ).tocsr()


def _laplace(arr, connectivity):
    ar1d = arr.ravel()
    unknown = np.isnan(ar1d)
    known = ~unknown
    # Set up system of equations
    A = connectivity.copy()
    A.setdiag(-A.sum(axis=1).A[:, 0])
    rhs = -A[:, known].dot(ar1d[known])
    out = ar1d.copy()
    # Linear solve
    out[unknown] = sparse.linalg.spsolve(A[unknown][:, unknown], rhs[unknown])
    return out.reshape(arr.shape)


def interpolate_na_laplace(da, dim, keep_attrs=True):
    shape = tuple(da.sizes[d] for d in dim)
    connectivity = _build_connectivity(shape)
    arr = xr.apply_ufunc(
        _laplace,
        da,
        input_core_dims=[dim],
        output_core_dims=[dim],
        output_dtypes=[da.dtype],
        dask="parallelized",
        vectorize=True,
        keep_attrs=keep_attrs,
        kwargs={"connectivity": connectivity},
    ).transpose(*da.dims)
    return arr

It tends to produce much nicer results if there are "island" shaped gaps, since it'll use all values along the boundary.

The downside is that it's computationally expensive, and for larger numbers of unknowns (more than roughly 10 000), the direct solve should be replaced by a conjugate-gradient iterative solver, which only works well with a decent preconditioner and therefore introduces a number of additional settings:

import warnings


def _laplace(arr, connectivity: sparse.csr_matrix, direct: bool):
    ar1d = arr.ravel()
    unknown = np.isnan(ar1d)
    known = ~unknown

    # Set up system of equations.
    matrix = connectivity.copy()
    matrix.setdiag(-matrix.sum(axis=1).A[:, 0])
    rhs = -matrix[:, known].dot(ar1d[known])

    # Linear solve for the unknowns.
    A = matrix[unknown][:, unknown]
    b = rhs[unknown]
    if direct:
        x = sparse.linalg.spsolve(A, b)
    else:  # Preconditioned conjugate-gradient linear solve.
        maxiter = 1000
        # Create preconditioner M
        M = ILU0Preconditioner.from_csr_matrix(A, delta=0.0, relax=0.97)
        # Call conjugate gradient solver
        x, info = sparse.linalg.cg(A, b, rtol=1e-05, atol=0.0, maxiter=maxiter, M=M)
        if info < 0:
            raise ValueError("scipy.sparse.linalg.cg: illegal input or breakdown")
        elif info > 0:
            warnings.warn(f"Failed to converge after {maxiter} iterations")

    out = ar1d.copy()
    out[unknown] = x
    return out.reshape(arr.shape)

The preconditioner is defined below. Scipy's spilu works very poorly for some reason, possibly due to how the factorization is chosen in the underlying SuperLU library; https://github.com/c-f-h/ilupp does much better.

The implementation here is a port of the (public domain) Fortran implementation in MODFLOW 6:

from typing import NamedTuple
import numba
import numpy as np
from scipy import sparse


FloatArray = np.ndarray
IntArray = np.ndarray


class MatrixCSR(NamedTuple):
    """
    More or less matches the scipy.sparse.csr_matrix.

    NamedTuple for easy ingestion by numba.
    """

    data: FloatArray
    indices: IntArray
    indptr: IntArray
    n: int
    m: int
    nnz: int

    @staticmethod
    def from_csr_matrix(A: sparse.csr_matrix) -> "MatrixCSR":
        n, m = A.shape
        return MatrixCSR(A.data, A.indices, A.indptr, n, m, A.nnz)


@numba.njit(inline="always")
def nzrange(A: MatrixCSR, row: int) -> range:
    """Return the non-zero indices of a single row."""
    start = A.indptr[row]
    end = A.indptr[row + 1]
    return range(start, end)


@numba.njit(inline="always")
def row_slice(A, row: int) -> slice:
    """Return the indices or data slice of a single row."""
    start = A.indptr[row]
    end = A.indptr[row + 1]
    return slice(start, end)


@numba.njit(inline="always")
def columns_and_values(A, slice):
    return zip(A.indices[slice], A.data[slice])


@numba.njit(inline="always")
def lower_slice(ilu, row: int) -> slice:
    return slice(ilu.indptr[row], ilu.uptr[row])


@numba.njit(inline="always")
def upper_slice(ilu, row: int) -> slice:
    return slice(ilu.uptr[row], ilu.indptr[row + 1])


@numba.njit
def set_uptr(ilu) -> None:
    # i is row index, j is column index
    for i in range(ilu.n):
        for nzi in nzrange(ilu, i):
            j = ilu.indices[nzi]
            if j > i:
                ilu.uptr[i] = nzi
                break
    return


@numba.njit
def _update(ilu, A: MatrixCSR, delta: float, relax: float):
    """
    Perform zero fill-in incomplete lower-upper (ILU0) factorization
    using the values of A.
    """
    ilu.work[:] = 0.0
    visited = np.full(ilu.n, False)

    # i is row index, j is column index, v is value.
    for i in range(ilu.n):
        for j, v in columns_and_values(A, row_slice(A, i)):
            visited[j] = True
            ilu.work[j] += v

        rs = 0.0
        for j in ilu.indices[lower_slice(ilu, i)]:
            # Compute row multiplier
            multiplier = ilu.work[j] * ilu.diagonal[j]
            ilu.work[j] = multiplier
            # Perform linear combination
            for jj, vv in columns_and_values(ilu, upper_slice(ilu, j)):
                if visited[jj]:
                    ilu.work[jj] -= multiplier * vv
                else:
                    rs += multiplier * vv

        diag = ilu.work[i]
        multiplier = (1.0 + delta) * diag - (relax * rs)
        # Work around a zero-valued pivot
        if (np.sign(multiplier) != np.sign(diag)) or (multiplier == 0):
            multiplier = np.sign(diag) * 1.0e-6
        ilu.diagonal[i] = 1.0 / multiplier

        # Reset work arrays, assign off-diagonal values
        visited[i] = False
        ilu.work[i] = 0.0
        for nzi in nzrange(ilu, i):
            j = ilu.indices[nzi]
            ilu.data[nzi] = ilu.work[j]
            ilu.work[j] = 0.0
            visited[j] = False

    return


@numba.njit
def _solve(ilu, r: np.ndarray):
    r"""
    LU \ r

    Stores the result in the pre-allocated work array.
    """
    ilu.work[:] = 0.0

    # forward
    for i in range(ilu.n):
        value = r[i]
        for j, v in columns_and_values(ilu, lower_slice(ilu, i)):
            value -= v * ilu.work[j]
        ilu.work[i] = value

    # backward
    for i in range(ilu.n - 1, -1, -1):
        value = ilu.work[i]
        for j, v in columns_and_values(ilu, upper_slice(ilu, i)):
            value -= v * ilu.work[j]
        ilu.work[i] = value * ilu.diagonal[i]

    return


class ILU0Preconditioner(NamedTuple):
    """
    Preconditioner based on zero fill-in lower-upper (ILU0) factorization.

    Data is stored in compressed sparse row (CSR) format. The diagonal
    values have been extracted for easier access. Upper and lower values
    are stored in CSR format. Next to the indptr array, which identifies
    the start and end of each row, the uptr array has been added to
    identify the start of the entries to the right of the diagonal. In case
    the part of the row to the right of the diagonal is empty, it contains
    the end of the row as indicated by the indptr array.

    Parameters
    ----------
    n: int
        Number of rows
    m: int
        Number of columns
    indptr: np.ndarray of int
        CSR format index pointer array of the matrix
    uptr: np.ndarray of int
        CSR format index pointer array of the upper elements (diagonal or higher)
    indices: np.ndarray of int
        CSR format index array of the matrix
    data: np.ndarray of float
        CSR format data array of the matrix
    diagonal: np.ndarray of float
        Diagonal values of LU factorization
    work: np.ndarray of float
        Work array. Used in factorization and solve.
    """

    n: int
    m: int
    indptr: IntArray
    uptr: IntArray
    indices: IntArray
    data: FloatArray
    diagonal: FloatArray
    work: FloatArray

    @property
    def shape(self) -> tuple[int, int]:
        return (self.n, self.m)

    @property
    def dtype(self):
        return self.data.dtype

    @staticmethod
    def from_csr_matrix(
        A: sparse.csr_matrix, delta: float = 0.0, relax: float = 0.0
    ) -> "ILU0Preconditioner":
        # Create a copy of the sparse matrix with the diagonals removed.
        n, m = A.shape
        coo = A.tocoo()
        i = coo.row
        j = coo.col
        offdiag = i != j
        ii = i[offdiag]
        indices = j[offdiag]
        indptr = sparse.csr_matrix((indices, (ii, indices)), shape=A.shape).indptr

        ilu = ILU0Preconditioner(
            n=n,
            m=m,
            indptr=indptr,
            uptr=indptr[1:].copy(),
            indices=indices,
            data=np.empty(indices.size),
            diagonal=np.empty(n),
            work=np.empty(n),
        )
        set_uptr(ilu)

        _update(ilu, MatrixCSR.from_csr_matrix(A), delta, relax)
        return ilu

    def update(self, A, delta=0.0, relax=0.0) -> None:
        _update(self, MatrixCSR.from_csr_matrix(A), delta, relax)
        return

    def matvec(self, r) -> FloatArray:
        _solve(self, r)
        return self.work

    def __repr__(self) -> str:
        return f"ILU0Preconditioner of type {self.dtype} and shape {self.shape}"

Not the best example, maybe, but it illustrates that this does quite well even when the data is 99% gaps:

from scipy import datasets
import PIL

f = datasets.face()
f_array = np.array(f).astype(float) / 255.0
da = xr.DataArray(f_array, dims=["y", "x", "bands"])
mask = xr.DataArray(np.random.choice([False, True], size=da.shape[:2], p=[0.99, 0.01]), dims=['y','x'])
masked = da.where(mask)

kw = {"yincrease": False}
fig,ax = plt.subplots(2,3,figsize=(11,7))
da.plot.imshow(ax=ax[0, 0],**kw)
masked.plot.imshow(ax=ax[0, 1],**kw)
interpolate_na_laplace(masked, ["y", "x"]).plot.imshow(ax=ax[0, 2],**kw)
interpolate_na(masked, ["y", "x"], method="nearest").plot.imshow(ax=ax[1, 0], **kw)
interpolate_na(masked, ["y", "x"], method="linear").plot.imshow(ax=ax[1, 1], **kw)
interpolate_na(masked, ["y", "x"], method="cubic").plot.imshow(ax=ax[1, 2], **kw)

[figure: original image, 99%-masked image, Laplace fill, and nearest/linear/cubic griddata fills]
