pydata / xarray

N-D labeled arrays and datasets in Python
https://xarray.dev
Apache License 2.0

Multidimensional `interpolate_na()` #6360

Open iuryt opened 2 years ago

iuryt commented 2 years ago

Is your feature request related to a problem?

I think that having a way to run a multidimensional interpolation for filling missing values would be awesome.

The code snippet below creates some data and shows the problem I am having now. If the data has some orientation, we can't simply interpolate along each dimension separately.

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

n = 30
x = xr.DataArray(np.linspace(0,2*np.pi,n),dims=['x'])
y = xr.DataArray(np.linspace(0,2*np.pi,n),dims=['y'])
z = (np.sin(x)*xr.ones_like(y))

mask = xr.DataArray(np.random.randint(0,1+1,(n,n)).astype('bool'),dims=['x','y'])

kw = dict(add_colorbar=False)

fig,ax = plt.subplots(1,3,figsize=(11,3))
z.plot(ax=ax[0],**kw)
z.where(mask).plot(ax=ax[1],**kw)
z.where(mask).interpolate_na('x').plot(ax=ax[2],**kw)

image

I tried to use advanced interpolation for that, but it doesn't look like the best solution.

zs = z.where(mask).stack(k=['x','y'])
zs = zs.where(np.isnan(zs),drop=True)
xi,yi = zs.k.x.drop('k'),zs.k.y.drop('k')
zi = z.interp(x=xi,y=yi)

fig,ax = plt.subplots()
z.where(mask).plot(ax=ax,**kw)
ax.scatter(xi,yi,c=zi,**kw,linewidth=1,edgecolor='k')

returns

image

Describe the solution you'd like

Simply z.interpolate_na(['x','y'])

Describe alternatives you've considered

I could extract the data to numpy and interpolate using scipy.interpolate.griddata, but this is not the way xarray should work.

Additional context

No response

hafez-ahmad commented 2 years ago

image

I am not getting a proper plot. It looks fine in ArcGIS 10.5, but when I try xarray, the plot seems to be missing many data or grid points.

Data source:
https://giovanni.gsfc.nasa.gov/session/B4259246-EB9D-11EA-A3A4-16015F835E51/D2BACBDC-CE49-11EC-9A98-0352ECEEFA7B/D2BAE400-CE49-11EC-9A98-0352ECEEFA7B///scrubbed.MODISA_L3m_ZLEE_2018_Zeu_lee.20140101.nc
https://giovanni.gsfc.nasa.gov/session/B4259246-EB9D-11EA-A3A4-16015F835E51/D2BACBDC-CE49-11EC-9A98-0352ECEEFA7B/D2BAE400-CE49-11EC-9A98-0352ECEEFA7B///scrubbed.MODISA_L3m_ZLEE_2018_Zeu_lee.20150101.nc

faceted plot by month

g = data_mean.MODISA_L3m_ZLEE_2018_Zeu_lee.plot(
    x='lon',
    y='lat',
    col='month',
    col_wrap=4,
    cmap='RdBu_r',
    subplot_kws={"projection": ccrs.Robinson()},
    figsize=(20, 20),
)
for i, ax in enumerate(g.axes.flat):
    ax.set_title(data_mean.month.values[i])
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS.with_scale('50m'), linewidth=0.5, edgecolor='black')
    ax.gridlines(crs=ccrs.PlateCarree(), linewidth=0.5, linestyle='-')

Would you please help me out with why I am not getting the proper surface?

Thank you

dcherian commented 2 years ago

@hafez-ahmad please open a new discussion topic with a fully reproducible example.

thomas-fred commented 1 year ago

I'd also find this very useful

TheJeran commented 1 year ago

Bumping

martin-wegmann commented 4 months ago

This would be super useful!

albertotb commented 4 months ago

+1. As an alternative I think interpolate_na from rioxarray supports this: https://corteva.github.io/rioxarray/html/examples/interpolate_na.html

keewis commented 3 months ago

you might be interested in pyinterp. With some extreme tuning, this can even reconstruct the original image (set nx=1, ny=9):

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pyinterp
import pyinterp.fill

n = 30
x = xr.DataArray(np.linspace(0, 2 * np.pi, n), dims=["x"])
y = xr.DataArray(np.linspace(0, 2 * np.pi, n), dims=["y"])
z = np.sin(x) * xr.ones_like(y)

mask = xr.DataArray(np.random.randint(0, 1 + 1, (n, n)).astype("bool"), dims=["x", "y"])
kw = dict(add_colorbar=False)

def interpolate_na(arr):
    x = pyinterp.Axis(arr.x.data)
    y = pyinterp.Axis(arr.y.data)

    z = arr.data
    grid = pyinterp.Grid2D(x, y, z)
    filled = pyinterp.fill.loess(grid, nx=3, ny=3)
    return arr.copy(data=filled)

fig, ax = plt.subplots(1, 4, figsize=(11, 4))
z.plot(ax=ax[0], **kw)
z.where(mask).plot(ax=ax[1], **kw)
z.where(mask).interpolate_na("x").plot(ax=ax[2], **kw)
z.where(mask).pipe(interpolate_na).plot(ax=ax[3], **kw)

It does have an xarray backend, but that does not seem to allow customizing the coordinate names: it insists on "latitude" and "longitude".

Huite commented 1 month ago

scipy.ndimage.distance_transform_edt is somewhat useful for a nearest implementation: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html

If the sampling argument isn't provided, it'll just look at rows/columns/etc., i.e. equivalent to use_coordinates=False.

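from scipy.ndimage import distance_transform_edt

def _nearest(a):
    nans = np.isnan(a)
    if not nans.any():
        return a.copy()
    # For every NaN cell, find the index of the nearest non-NaN cell.
    indices = distance_transform_edt(
        input=nans,
        return_distances=False,
        return_indices=True,
    )
    return a[tuple(indices)]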

def interpolate_na(da, dim, keep_attrs=True):
    arr = xr.apply_ufunc(
        _nearest,
        da,
        input_core_dims=[dim],
        output_core_dims=[dim],
        output_dtypes=[da.dtype],
        dask="parallelized",
        vectorize=True,
        keep_attrs=keep_attrs,
    ).transpose(*da.dims)
    return arr

fig,ax = plt.subplots(1,4,figsize=(14,3))
z.plot(ax=ax[0],**kw)
z.where(mask).plot(ax=ax[1],**kw)
z.where(mask).interpolate_na('x').plot(ax=ax[2],**kw)
interpolate_na(z.where(mask), ["x", "y"]).plot(ax=ax[3],**kw)

image

Interpolating ["x", "y"] versus ["y", "x"] will give different answers: when several nearest neighbors are equally far away (e.g. one cell removed along either dimension), scipy.ndimage.distance_transform_edt will choose the one along the last dimension.

The sampling argument unfortunately only accepts a sequence of floats (one for each dimension), so that only works for axis-aligned, equidistant coordinates.
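For the equidistant case, the spacing can be derived from the coordinate arrays and passed on as sampling. A minimal sketch on plain NumPy arrays; the helper name `nearest_fill` and its equidistance check are my own assumptions for illustration, not part of scipy or xarray:

```python
import numpy as np
from scipy.ndimage import distance_transform_edt


def nearest_fill(a, coords=None):
    """Fill NaNs in `a` with the nearest valid value.

    `coords` is an optional sequence of 1D coordinate arrays, one per axis.
    They must be equidistant, because `sampling` takes one float per axis.
    """
    nans = np.isnan(a)
    if not nans.any():
        return a.copy()

    sampling = None
    if coords is not None:
        sampling = []
        for c in coords:
            steps = np.diff(np.asarray(c, dtype=float))
            if not np.allclose(steps, steps[0]):
                raise ValueError("coordinates must be equidistant")
            sampling.append(float(abs(steps[0])))

    # For every NaN cell, look up the index of the nearest non-NaN cell,
    # measuring distance in coordinate units when `sampling` is given.
    indices = distance_transform_edt(
        nans,
        sampling=sampling,
        return_distances=False,
        return_indices=True,
    )
    return a[tuple(indices)]


a = np.array([[1.0, 2.0, 3.0],
              [4.0, np.nan, 6.0],
              [7.0, 8.0, 9.0]])
y = np.array([0.0, 10.0, 20.0])  # coarse spacing along axis 0
x = np.array([0.0, 1.0, 2.0])    # fine spacing along axis 1
filled = nearest_fill(a, coords=[y, x])
```

With these coordinates the gap is 10 units away from its neighbors along axis 0 but only 1 unit away along axis 1, so it is filled from the left or right neighbor rather than from above or below.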

Huite commented 1 month ago

Rioxarray's use of griddata can be made a little easier with apply_ufunc:


from scipy.interpolate import griddata

def _griddata(arr, xi, method: str):
    ar1d = arr.ravel()
    valid = np.isfinite(ar1d)
    if valid.all():
        return arr
    return griddata(
        points=tuple(x[valid] for x in xi),
        values=ar1d[valid],
        xi=xi,
        method=method,
        fill_value=np.nan,
    ).reshape(arr.shape)

def interpolate_na(da, dim, method="nearest", use_coordinates=True, keep_attrs=True):
    # Create points only once.
    if use_coordinates:
        coords = [da.coords[d] for d in dim]
    else:
        coords = [np.arange(da.sizes[d]) for d in dim]

    xi = tuple(x.ravel() for x in np.meshgrid(*coords, indexing="ij"))
    arr = xr.apply_ufunc(
        _griddata,
        da,
        input_core_dims=[dim],
        output_core_dims=[dim],
        output_dtypes=[da.dtype],
        dask="parallelized",
        vectorize=True,
        keep_attrs=keep_attrs,
        kwargs={"xi": xi, "method": method},
    ).transpose(*da.dims)
    return arr

fig,ax = plt.subplots(1,3,figsize=(11,3))
interpolate_na(z.where(mask), ["y", "x"], method="nearest").plot(ax=ax[0], **kw)
interpolate_na(z.where(mask), ["y", "x"], method="linear").plot(ax=ax[1], **kw)
interpolate_na(z.where(mask), ["y", "x"], method="cubic").plot(ax=ax[2], **kw)

image

griddata would work for non-1D coordinates as well, with a little extra logic.
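A rough sketch of what that extra logic could look like on plain NumPy arrays, assuming the 2D coordinates are arrays with the same shape as the data. The helper name `fill_with_2d_coords` and the rotated synthetic grid are made up for illustration:

```python
import numpy as np
from scipy.interpolate import griddata


def fill_with_2d_coords(values, xc, yc, method="linear"):
    """Fill NaNs in `values` using 2D coordinate arrays `xc` and `yc`.

    All three arrays share the same shape; no meshgrid is needed because
    every cell already carries its own (x, y) location.
    """
    valid = np.isfinite(values)
    if valid.all():
        return values.copy()
    points = np.column_stack([xc[valid], yc[valid]])
    targets = np.column_stack([xc[~valid], yc[~valid]])
    out = values.copy()
    # Only the gaps are interpolated; valid cells are kept as they are.
    # With "linear" or "cubic", gaps outside the convex hull of the valid
    # points stay NaN.
    out[~valid] = griddata(points, values[valid], targets, method=method)
    return out


# Synthetic curvilinear grid: a regular grid rotated by 30 degrees.
i, j = np.meshgrid(np.arange(20.0), np.arange(20.0), indexing="ij")
theta = np.deg2rad(30.0)
xc = np.cos(theta) * i - np.sin(theta) * j
yc = np.sin(theta) * i + np.cos(theta) * j
truth = 2.0 * xc + 3.0 * yc  # a plane, which linear interpolation reproduces

values = truth.copy()
values[5:8, 9:12] = np.nan  # an interior gap
filled = fill_with_2d_coords(values, xc, yc)
```

Wrapping this in apply_ufunc would additionally require passing the 2D coordinates with the same core dimensions as the data.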

Huite commented 1 month ago

As a final note: I'm personally quite fond of "Laplace interpolation" (see e.g. chapter 3.8 of Numerical Recipes for the idea):

from scipy import sparse

def _build_connectivity(shape):
    # Get the Cartesian neighbors for a finite difference approximation.
    # TODO: check order of dimensions with DataArray
    size = np.prod(shape)
    index = np.arange(size).reshape(shape)

    # Build nD connectivity
    ii = []
    jj = []
    for d in range(len(shape)):
        slices = [slice(None)] * len(shape)

        slices[d] = slice(None, -1)
        left = index[tuple(slices)].ravel()
        slices[d] = slice(1, None)
        right = index[tuple(slices)].ravel()
        ii.extend([left, right])
        jj.extend([right, left])

    i = np.concatenate(ii)
    j = np.concatenate(jj)
    return sparse.coo_matrix(
        (np.ones(len(i)), (i, j)),
        shape=(size, size)
    ).tocsr()

def _laplace(arr, connectivity):
    ar1d = arr.ravel()
    unknown = np.isnan(ar1d)
    known = ~unknown
    # Set up system of equations
    A = connectivity.copy()
    A.setdiag(-A.sum(axis=1).A[:, 0])
    rhs = -A[:, known].dot(ar1d[known])
    out = ar1d.copy()
    # Linear solve
    out[unknown] = sparse.linalg.spsolve(A[unknown][:, unknown], rhs[unknown])
    return out.reshape(arr.shape)

def interpolate_na_laplace(da, dim, keep_attrs=True):
    shape = tuple(da.sizes[d] for d in dim)
    connectivity = _build_connectivity(shape)
    arr = xr.apply_ufunc(
        _laplace,
        da,
        input_core_dims=[dim],
        output_core_dims=[dim],
        output_dtypes=[da.dtype],
        dask="parallelized",
        vectorize=True,
        keep_attrs=keep_attrs,
        kwargs={"connectivity": connectivity},
    ).transpose(*da.dims)
    return arr

It tends to produce much nicer results if there are "island" shaped gaps, since it'll use all values along the boundary.
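To illustrate what the linear system encodes: every unknown cell ends up equal to the average of its direct neighbors, so values diffuse inward from the whole boundary of a gap. A dense, NumPy-only sketch using plain Jacobi relaxation, which is a swapped-in and much slower solver than the sparse solve above; the name `laplace_fill_jacobi` is hypothetical:

```python
import numpy as np


def laplace_fill_jacobi(a, n_iter=2000):
    """Fill NaNs in a 2D array by repeatedly averaging direct neighbors.

    Converges towards the solution of the discrete Laplace equation,
    with the known cells acting as fixed values.
    """
    unknown = np.isnan(a)
    out = np.where(unknown, np.nanmean(a), a)  # initial guess for the gaps

    for _ in range(n_iter):
        # Pad with NaN so that cells outside the array are not counted.
        p = np.pad(out, 1, constant_values=np.nan)
        neighbors = np.stack(
            [p[:-2, 1:-1], p[2:, 1:-1], p[1:-1, :-2], p[1:-1, 2:]]
        )
        mean = np.nanmean(neighbors, axis=0)
        # Only the gaps are updated; known cells never change.
        out[unknown] = mean[unknown]
    return out


a = np.array([[1.0, 2.0, 3.0],
              [4.0, np.nan, 6.0],
              [7.0, 8.0, 9.0]])
filled = laplace_fill_jacobi(a)
# The single gap becomes the mean of its neighbors 2, 4, 6 and 8.
```

For anything beyond toy sizes the sparse formulation is the better choice; this version only serves to show the averaging property.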

The downside is that it's computationally expensive, and for more unknowns (more than 10 000 or so) the direct solve should be replaced by a conjugate-gradient iterative solver. That only works well with a decent preconditioner, which introduces a number of additional settings:

import warnings

def _laplace(arr, connectivity: sparse.csr_matrix, direct: bool):
    ar1d = arr.ravel()
    unknown = np.isnan(ar1d)
    known = ~unknown

    # Set up system of equations.
    matrix = connectivity.copy()
    matrix.setdiag(-matrix.sum(axis=1).A[:, 0])
    rhs = -matrix[:, known].dot(ar1d[known])

    # Linear solve for the unknowns.
    A = matrix[unknown][:, unknown]
    b = rhs[unknown]
    if direct:
        x = sparse.linalg.spsolve(A, b)
    else:  # Preconditioned conjugate-gradient linear solve.
        maxiter = 1000
        # Create preconditioner M
        M = ILU0Preconditioner.from_csr_matrix(A, delta=0.0, relax=0.97)
        # Call conjugate gradient solver
        x, info = sparse.linalg.cg(A, b, rtol=1e-05, atol=0.0, maxiter=maxiter, M=M)
        if info < 0:
            raise ValueError("scipy.sparse.linalg.cg: illegal input or breakdown")
        elif info > 0:
            warnings.warn(f"Failed to converge after {maxiter} iterations")

    out = ar1d.copy()
    out[unknown] = x
    return out.reshape(arr.shape)
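If a full ILU0 preconditioner is more machinery than you want, a Jacobi (diagonal) preconditioner is a much simpler stand-in. It is generally weaker, so expect more iterations on large gaps. A sketch under that assumption: `solve_cg_jacobi` is a hypothetical helper, and the system is negated so that CG is handed a positive definite matrix:

```python
import numpy as np
from scipy import sparse
from scipy.sparse.linalg import LinearOperator, cg


def solve_cg_jacobi(A, b, maxiter=1000):
    """Solve A x = b with CG and a diagonal (Jacobi) preconditioner.

    A is expected to be a negative definite Laplace system like the one
    set up above, so both sides are negated before calling CG.
    """
    A_pos = (-A).tocsr()
    inv_diag = 1.0 / A_pos.diagonal()
    # The preconditioner only rescales the residual by the inverse diagonal.
    M = LinearOperator(
        A_pos.shape,
        matvec=lambda r: inv_diag * np.ravel(r),
        dtype=A_pos.dtype,
    )
    x, info = cg(A_pos, -b, maxiter=maxiter, M=M)
    if info != 0:
        raise RuntimeError(f"CG did not converge (info={info})")
    return x


# 1D toy problem: five unknowns between fixed end values 0 and 6.
A = sparse.diags([1.0, -2.0, 1.0], [-1, 0, 1], shape=(5, 5), format="csr")
b = np.array([0.0, 0.0, 0.0, 0.0, -6.0])
x = solve_cg_jacobi(A, b)
```

In one dimension the Laplace solution is just the straight line between the two end values, which makes the result easy to check by eye.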

Preconditioner here:

Scipy's spilu works very poorly for some reason, possibly due to the factorization chosen in the underlying SuperLU library. https://github.com/c-f-h/ilupp does much better. The implementation here is a port of the (public domain) Fortran implementation in MODFLOW 6:

```python
from typing import NamedTuple

import numba
import numpy as np
from scipy import sparse

FloatArray = np.ndarray
IntArray = np.ndarray


class MatrixCSR(NamedTuple):
    """
    More or less matches the scipy.sparse.csr_matrix.

    NamedTuple for easy ingestion by numba.
    """

    data: FloatArray
    indices: IntArray
    indptr: IntArray
    n: int
    m: int
    nnz: int

    @staticmethod
    def from_csr_matrix(A: sparse.csr_matrix) -> "MatrixCSR":
        n, m = A.shape
        return MatrixCSR(A.data, A.indices, A.indptr, n, m, A.nnz)


@numba.njit(inline="always")
def nzrange(A: MatrixCSR, row: int) -> range:
    """Return the non-zero indices of a single row."""
    start = A.indptr[row]
    end = A.indptr[row + 1]
    return range(start, end)


@numba.njit(inline="always")
def row_slice(A, row: int) -> slice:
    """Return the indices or data slice of a single row."""
    start = A.indptr[row]
    end = A.indptr[row + 1]
    return slice(start, end)


@numba.njit(inline="always")
def columns_and_values(A, slice):
    return zip(A.indices[slice], A.data[slice])


@numba.njit(inline="always")
def lower_slice(ilu, row: int) -> slice:
    return slice(ilu.indptr[row], ilu.uptr[row])


@numba.njit(inline="always")
def upper_slice(ilu, row: int) -> slice:
    return slice(ilu.uptr[row], ilu.indptr[row + 1])


@numba.njit
def set_uptr(ilu) -> None:
    # i is row index, j is column index
    for i in range(ilu.n):
        for nzi in nzrange(ilu, i):
            j = ilu.indices[nzi]
            if j > i:
                ilu.uptr[i] = nzi
                break
    return


@numba.njit
def _update(ilu, A: MatrixCSR, delta: float, relax: float):
    """
    Perform zero fill-in incomplete lower-upper (ILU0) factorization
    using the values of A.
    """
    ilu.work[:] = 0.0
    visited = np.full(ilu.n, False)

    # i is row index, j is column index, v is value.
    for i in range(ilu.n):
        for j, v in columns_and_values(A, row_slice(A, i)):
            visited[j] = True
            ilu.work[j] += v

        rs = 0.0
        for j in ilu.indices[lower_slice(ilu, i)]:
            # Compute row multiplier
            multiplier = ilu.work[j] * ilu.diagonal[j]
            ilu.work[j] = multiplier
            # Perform linear combination
            for jj, vv in columns_and_values(ilu, upper_slice(ilu, j)):
                if visited[jj]:
                    ilu.work[jj] -= multiplier * vv
                else:
                    rs += multiplier * vv

        diag = ilu.work[i]
        multiplier = (1.0 + delta) * diag - (relax * rs)
        # Work around a zero-valued pivot
        if (np.sign(multiplier) != np.sign(diag)) or (multiplier == 0):
            multiplier = np.sign(diag) * 1.0e-6
        ilu.diagonal[i] = 1.0 / multiplier

        # Reset work arrays, assign off-diagonal values
        visited[i] = False
        ilu.work[i] = 0.0
        for nzi in nzrange(ilu, i):
            j = ilu.indices[nzi]
            ilu.data[nzi] = ilu.work[j]
            ilu.work[j] = 0.0
            visited[j] = False

    return


@numba.njit
def _solve(ilu, r: np.ndarray):
    r"""
    LU \ r

    Stores the result in the pre-allocated work array.
    """
    ilu.work[:] = 0.0

    # forward
    for i in range(ilu.n):
        value = r[i]
        for j, v in columns_and_values(ilu, lower_slice(ilu, i)):
            value -= v * ilu.work[j]
        ilu.work[i] = value

    # backward
    for i in range(ilu.n - 1, -1, -1):
        value = ilu.work[i]
        for j, v in columns_and_values(ilu, upper_slice(ilu, i)):
            value -= v * ilu.work[j]
        ilu.work[i] = value * ilu.diagonal[i]

    return


class ILU0Preconditioner(NamedTuple):
    """
    Preconditioner based on zero fill-in lower-upper (ILU0) factorization.

    Data is stored in compressed sparse row (CSR) format. The diagonal
    values have been extracted for easier access. Upper and lower values
    are stored in CSR format. Next to the indptr array, which identifies
    the start and end of each row, the uptr array has been added to
    identify the start to the right of the diagonal. In case the row to the
    right of the diagonal is empty, it contains the end of the rows as
    indicated by the indptr array.

    Parameters
    ----------
    n: int
        Number of rows
    m: int
        Number of columns
    indptr: np.ndarray of int
        CSR format index pointer array of the matrix
    uptr: np.ndarray of int
        CSR format index pointer array of the upper elements (diagonal or higher)
    indices: np.ndarray of int
        CSR format index array of the matrix
    data: np.ndarray of float
        CSR format data array of the matrix
    diagonal: np.ndarray of float
        Diagonal values of LU factorization
    work: np.ndarray of float
        Work array. Used in factorization and solve.
    """

    n: int
    m: int
    indptr: IntArray
    uptr: IntArray
    indices: IntArray
    data: FloatArray
    diagonal: FloatArray
    work: FloatArray

    @property
    def shape(self) -> tuple[int, int]:
        return (self.n, self.m)

    @property
    def dtype(self):
        return self.data.dtype

    @staticmethod
    def from_csr_matrix(
        A: sparse.csr_matrix, delta: float = 0.0, relax: float = 0.0
    ) -> "ILU0Preconditioner":
        # Create a copy of the sparse matrix with the diagonals removed.
        n, m = A.shape
        coo = A.tocoo()
        i = coo.row
        j = coo.col
        offdiag = i != j
        ii = i[offdiag]
        indices = j[offdiag]
        indptr = sparse.csr_matrix((indices, (ii, indices)), shape=A.shape).indptr

        ilu = ILU0Preconditioner(
            n=n,
            m=m,
            indptr=indptr,
            uptr=indptr[1:].copy(),
            indices=indices,
            data=np.empty(indices.size),
            diagonal=np.empty(n),
            work=np.empty(n),
        )
        set_uptr(ilu)

        _update(ilu, MatrixCSR.from_csr_matrix(A), delta, relax)
        return ilu

    def update(self, A, delta=0.0, relax=0.0) -> None:
        _update(self, MatrixCSR.from_csr_matrix(A), delta, relax)
        return

    def matvec(self, r) -> FloatArray:
        _solve(self, r)
        return self.work

    def __repr__(self) -> str:
        return f"ILU0Preconditioner of type {self.dtype} and shape {self.shape}"
```

Not the best example, maybe, but it illustrates that this does quite well even when 99% of the data is gaps:

from scipy import datasets
import PIL

f = datasets.face()
f_array = np.array(f).astype(float) / 255.0
da = xr.DataArray(f_array, dims=["y", "x", "bands"])
mask = xr.DataArray(np.random.choice([False, True], size=da.shape[:2], p=[0.99, 0.01]), dims=['y','x'])
masked = da.where(mask)

kw = {"yincrease": False}
fig,ax = plt.subplots(2,3,figsize=(11,7))
da.plot.imshow(ax=ax[0, 0],**kw)
masked.plot.imshow(ax=ax[0, 1],**kw)
interpolate_na_laplace(masked, ["y", "x"]).plot.imshow(ax=ax[0, 2],**kw)
interpolate_na(masked, ["y", "x"], method="nearest").plot.imshow(ax=ax[1, 0], **kw)
interpolate_na(masked, ["y", "x"], method="linear").plot.imshow(ax=ax[1, 1], **kw)
interpolate_na(masked, ["y", "x"], method="cubic").plot.imshow(ax=ax[1, 2], **kw)

image