jgrss / geowombat

GeoWombat: Utilities for geospatial data
https://geowombat.readthedocs.io
MIT License
182 stars, 10 forks

Apply interpolation #286

Closed (mmann1123 closed 9 months ago)

mmann1123 commented 9 months ago

@jgrss I have been working on alternative ways to interpolate missing values in a time series. The apply method initially seemed like a good option, but it seems to be able to write only a single date back out, hence the use of return array[self.index_to_write].squeeze().

Just wondering if there is another way to apply the function that is just as fast, or if I need to use something like map_blocks instead (which I think is much slower).

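# not working because I can't write out multiple observations
import geowombat as gw
import jax.numpy as jnp
import numpy as np


def _interpolate_nans_linear(array):
    # Fill NaNs in a 1d time series by linear interpolation over the index
    nan_mask = np.isnan(array)
    if nan_mask.all():
        # No valid observations to interpolate from
        return array
    index = np.arange(len(array))
    return np.interp(index, index[~nan_mask], array[~nan_mask])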

class interpolate_nan(gw.TimeModule):
    """Interpolate missing values in the time series using linear interpolation.

    Args:
        missing_value (int, optional): The value to be replaced by NaNs. Default is None.
        interp_type (str, optional): The type of interpolation to use. Default is "linear".
        index_to_write (int, optional): The index of the interpolated array to return. Default is 0.

    """

    def __init__(self, missing_value=None, interp_type="linear", index_to_write=0):
        super(interpolate_nan, self).__init__()
        self.missing_value = missing_value
        self.interp_type = interp_type
        self.index_to_write = index_to_write

    def calculate(self, array):
        # Replace the missing value with NaN, unless it is None or already NaN
        if self.missing_value is not None:
            if not np.isnan(self.missing_value):
                array = jnp.where(array == self.missing_value, np.nan, array)
        # Interpolate along the time axis, whether or not missing_value was given
        if self.interp_type == "linear":
            array = np.apply_along_axis(_interpolate_nans_linear, axis=0, arr=array)
        # Return one of the interpolated arrays based on index_to_write
        return array[self.index_to_write].squeeze()

with gw.series(
    files,
    nodata=9999,
) as src:
    src.apply(
        func=interpolate_nan(missing_value=0),
        outfile="/home/mmann1123/Downloads/test.tif",
        num_workers=5,
        bands=1,
    )
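
For reference, here is a small NumPy-only check of what the linear helper does to a single pixel's time series. The sample values are made up for illustration, and the helper is restated under a hypothetical name so the snippet runs on its own. Note that np.interp holds the first and last valid values at the edges rather than extrapolating.

```python
import numpy as np


def interpolate_nans_linear(series):
    """Fill NaNs in a 1d series by linear interpolation over the index."""
    series = np.asarray(series, dtype=float)
    nan_mask = np.isnan(series)
    if nan_mask.all():
        # Nothing to interpolate from, so return the series unchanged
        return series
    index = np.arange(series.size)
    return np.interp(index, index[~nan_mask], series[~nan_mask])


# Hypothetical pixel time series with a leading, an interior, and a trailing gap
pixel = np.array([np.nan, 2.0, np.nan, 4.0, np.nan])
filled = interpolate_nans_linear(pixel)
print(filled)  # [2. 2. 3. 4. 4.]
```

The interior gap is filled with the midpoint of its neighbours, while the leading and trailing gaps take the nearest valid value.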

jgrss commented 9 months ago

Hey @mmann1123 I think I follow your setup, except for the shape of your data. What is the shape of array[self.index_to_write]? Is array 4d (time x bands x height x width) and then you slice to get (1 x bands (1?) x height x width), and then squeeze to (time x height x width)?

You can control part of the output profile in your user function. For example, the output band count is set by TimeModule.count (1 by default), and that value gets passed to the rasterio profile.

If you want a multi-band output then you can specify that in your user function. I don't think you need index_to_write, so I replaced it with count below in your __init__ method and in the return. But I think that assumes you are processing multi-temporal, single band data. See my comment at the bottom.

class interpolate_nan(gw.TimeModule):
    def __init__(self, missing_value=None, interp_type="linear", count=1):
        super(interpolate_nan, self).__init__()
        self.missing_value = missing_value
        self.interp_type = interp_type
        # Overrides the default output band count
        self.count = count

    def calculate(self, array):
        # Replace the missing value with NaN, unless it is None or already NaN
        if self.missing_value is not None:
            if not np.isnan(self.missing_value):
                array = jnp.where(array == self.missing_value, np.nan, array)
        # Interpolate along the time axis, whether or not missing_value was given
        if self.interp_type == "linear":
            array = np.apply_along_axis(_interpolate_nans_linear, axis=0, arr=array)
        # Return the interpolated array (3d -> time/bands x height x width)
        # If the array is (time x 1 x height x width) then squeeze to 3d
        return array.squeeze()

Then, you should be able to use it like this:

with gw.series(
    files,
    nodata=9999,
) as src:
    src.apply(
        func=interpolate_nan(
            missing_value=0,
            # not sure if your output length matches your input file length
            # whatever your case is, this is where you define the output band count
            count=len(src.filenames),
        ),
        outfile="/home/mmann1123/Downloads/test.tif",
        num_workers=5,
        # Note that this is the band, or bands, to read
        bands=1,
    )

Note that you can only write a 3d array. Therefore, you can either write a single interpolated date and multiple bands, or all the dates for a single band.
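
The shape constraint can be illustrated without geowombat. This is a minimal sketch that assumes calculate receives a (time x bands x height x width) block; the synthetic array and the helper name fill_time_axis are hypothetical stand-ins, not part of the library. For a single band, squeezing leaves (time x height x width), so the output band count equals the number of dates.

```python
import numpy as np


def fill_time_axis(block):
    """Linearly interpolate NaNs along axis 0 (time) of an n-d block."""

    def _fill(series):
        nan_mask = np.isnan(series)
        if nan_mask.all():
            # No valid observations for this pixel
            return series
        index = np.arange(series.size)
        return np.interp(index, index[~nan_mask], series[~nan_mask])

    return np.apply_along_axis(_fill, 0, block)


# Synthetic block: 4 dates, 1 band, 2 x 3 pixels (stand-in for real data)
block = np.arange(24, dtype=float).reshape(4, 1, 2, 3)
block[1, 0, 0, 0] = np.nan  # knock out one observation

filled = fill_time_axis(block)
out = filled.squeeze()  # drops the length-1 band axis

print(filled.shape)  # (4, 1, 2, 3)
print(out.shape)     # (4, 2, 3)
```

Here out.shape[0] is the value you would pass as count when writing all dates for one band.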

jgrss commented 9 months ago

Did you also try xarray's interpolate method?

with gw.open(files, chunks={'time': -1}) as src:
    interp = (
        src.interpolate_na(dim='time', method='linear', fill_value='extrapolate')
        .bfill(dim='time')
        .ffill(dim='time')
        # Interpolate to new grid
        #.interp(time=smooth_range, method='slinear')
    )

mmann1123 commented 9 months ago

I had been using xarray's interpolate_na, but now I have files that are too big to bring into memory, and I am not sure how to apply it to chunks and write the result out.

mmann1123 commented 9 months ago

OK, you are a lifesaver as always. Yes, I was applying the interpolation to a single band for multiple periods. Your apply example worked! Thanks again.