intake / intake-xarray

Intake plugin for xarray
https://intake-xarray.readthedocs.io/
BSD 2-Clause "Simplified" License

sel transform #118

Closed: raybellwaves closed this pull request 2 years ago

raybellwaves commented 2 years ago

Conversation started at https://github.com/intake/intake/pull/663

Ready for review

I believe I have this working outside of the repo.

my_derived.py:

from intake import Schema
from intake.source.derived import GenericTransform

class XArrayTransform(GenericTransform):
    """Transform where the input and output are both xarray objects.
    You must supply ``transform`` and any ``transform_kwargs``.
    """

    input_container = "xarray"
    container = "xarray"
    optional_params = {}
    _ds = None

    def to_dask(self):
        if self._ds is None:
            self._pick()
            self._ds = self._transform(
                self._source.to_dask(), **self._params["transform_kwargs"]
            )
        return self._ds

    def _get_schema(self):
        """load metadata only if needed"""
        self.to_dask()
        return Schema(
            datashape=None,
            dtype=None,
            shape=None,
            npartitions=None,
            # xarray objects have no .extra_metadata; their metadata lives in .attrs
            extra_metadata=self._ds.attrs,
        )

    def read(self):
        return self.to_dask().compute()

class Sel(XArrayTransform):
    """Simple array transform to subsample an xarray object using
    the sel method.
    Note that you could use XArrayTransform directly, by writing a
    function to choose the subsample instead of a method as here.
    """

    input_container = "xarray"
    container = "xarray"
    required_params = ["indexers"]

    def __init__(self, indexers, **kwargs):
        """
        indexers: dict (stored as str) which is passed to xarray.Dataset.sel
        """
        # This class requires "indexers", but XArrayTransform expects
        # "transform_kwargs", which we don't need here because a method
        # is used as the transform.
        kwargs.update(
            transform=self.sel,
            indexers=indexers,
            transform_kwargs={},
        )
        super().__init__(**kwargs)

    def sel(self, ds):
        return ds.sel(eval(self._params["indexers"]))

def _sel(ds, indexers: str):
    """indexers: dict (stored as str) which is passed to xarray.Dataset.sel"""
    # eval runs arbitrary Python taken from the catalog; only use trusted catalogs
    return ds.sel(eval(indexers))
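Both `Sel.sel` and `_sel` hand the catalog string to `eval`, so anyone who can edit the catalog can run arbitrary Python. As a hedged sketch, assuming indexers only ever hold literal labels, a hypothetical helper `parse_indexers` (not part of intake or intake-xarray) could parse both the `dict([...])` form used in test.yml and a plain dict literal with the standard-library `ast` module instead:

```python
import ast


def parse_indexers(text):
    """Parse an indexers string into a dict without calling eval.

    Accepts a dict literal such as "{'lat': 20}" or a call of the form
    "dict([('lat', 20)])" whose single argument is a literal list of pairs.
    Anything else raises ValueError.
    """
    tree = ast.parse(text, mode="eval").body
    # Form 1: "dict([...])" with exactly one positional literal argument.
    if (
        isinstance(tree, ast.Call)
        and isinstance(tree.func, ast.Name)
        and tree.func.id == "dict"
        and len(tree.args) == 1
        and not tree.keywords
    ):
        pairs = ast.literal_eval(tree.args[0])  # raises ValueError if not a literal
        try:
            return dict(pairs)
        except (TypeError, ValueError) as exc:
            raise ValueError("dict(...) argument must be a list of pairs") from exc
    # Form 2: a plain literal such as "{'lat': 20}"; non-literals raise ValueError.
    value = ast.literal_eval(tree)
    if not isinstance(value, dict):
        raise ValueError(f"indexers must describe a dict, got {type(value).__name__}")
    return value


print(parse_indexers("dict([('lat', 20)])"))  # prints {'lat': 20}
```

`_sel` could then call `ds.sel(parse_indexers(indexers))`. One trade-off: `ast.literal_eval` rejects calls such as `slice(10, 20)`, so range selections that `eval` would accept would need an extra rule.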

test.yml:

metadata:
  version: 1
sources:
  xarray_source:
    description: example xarray source plugin
    driver: netcdf
    args:
      urlpath: 'example_1.nc'
      chunks: {}

  xarray_source_sel:
    description: select subsample of xarray_source entry
    driver: my_derived.XArrayTransform
    args:
      targets:
        - xarray_source
      transform: "my_derived._sel"
      transform_kwargs:
        indexers: "dict([('lat', 20)])"
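In test.yml, both `driver: my_derived.XArrayTransform` and `transform: "my_derived._sel"` are dotted import paths that must be turned into Python objects when the catalog is loaded. The snippet below is a standard-library sketch of that kind of lookup using `importlib`; `resolve_dotted` is a hypothetical name, not intake's own resolver, and it assumes `my_derived.py` is importable (for example, on `sys.path`).

```python
import importlib


def resolve_dotted(path):
    """Return the object named by a dotted path like "package.module.attr".

    Tries the longest importable module prefix first, then walks the
    remaining names as attributes of that module.
    """
    parts = path.split(".")
    for i in range(len(parts) - 1, 0, -1):
        module_name = ".".join(parts[:i])
        try:
            obj = importlib.import_module(module_name)
        except ImportError:
            continue  # this prefix is not a module; try a shorter one
        for attr in parts[i:]:
            obj = getattr(obj, attr)  # AttributeError if the name is missing
        return obj
    raise ImportError(f"no importable module prefix in {path!r}")
```

With `my_derived.py` on the path, `resolve_dotted("my_derived._sel")` would return the `_sel` function defined above.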
[Screenshot attached: 2022-05-22, 11:43 PM]
martindurant commented 2 years ago

I think you are right and this fits well here. Unfortunately, there are a couple of things to fix for the upstream tests (as you found previously). The attribute error seems simple enough; I'm not sure about the import error. Do you think you can fix these here?

martindurant commented 2 years ago

OK, I am totally happy to merge this (sorry for taking so long), and I'll be sure to get around to fixing the upstream build.