phausamann / sklearn-xarray

Metadata-aware machine learning.
http://bit.do/sklearn-xarray
BSD 3-Clause "New" or "Revised" License

preprocessing.Splitter and Segmenter: allow reduce_index to be a callable. #1

Open phausamann opened 6 years ago

phausamann commented 6 years ago

It should be possible to pass a callable as reduce_index which takes the number of samples after transformation as an input and returns an index of this length.
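For illustration, such a callable might look like the following (a hypothetical sketch of the requested behaviour, not something the current reduce_index parameter accepts):

import pandas as pd

def reduce_index(n_samples):
    # Receives the number of samples after the transformation and
    # returns a new index with exactly one entry per sample.
    return pd.RangeIndex(n_samples)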

charlesbmi commented 2 years ago

@phausamann, it would also be really useful to enable Selector to select based on a criterion/parameter (e.g., if time is a dimension, selecting data where time < 1). This would be helpful, for example, for testing which time range is most predictive of the target label.

Is there a way to do this? Thanks!

phausamann commented 2 years ago

Hi @charlesincharge, that should be pretty easy to implement by adding a new parameter to Selector, something like mode="lt" (less than) in your case.

As a workaround, you could add an extra boolean coordinate to your DataArray/Dataset, like this:

da.coords["time_selector"] = da.time < 1

and then use the Selector based on that coordinate.

Unless you want to do this dynamically in a pipeline, in which case the workaround probably won't work.
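A minimal sketch of that static workaround in plain xarray (the Selector step itself is omitted here, since its exact parameters are best checked in the docs; the toy data is purely illustrative):

import numpy as np
import xarray as xr

# Toy data: 100 samples x 4 features with a "time" coordinate
da = xr.DataArray(
    np.random.rand(100, 4),
    dims=("sample", "feature"),
    coords={"time": ("sample", np.linspace(0.0, 2.0, 100))},
)

# Extra boolean coordinate encoding the (fixed) selection criterion
da.coords["time_selector"] = da.time < 1

# Equivalent plain-xarray selection; in a pipeline, the sklearn-xarray
# Selector would be pointed at the "time_selector" coordinate instead.
da_early = da.isel(sample=np.flatnonzero(da.time_selector.values))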

charlesbmi commented 2 years ago

Thanks for the suggestion, @phausamann.

I am actually interested in doing this dynamically in a pipeline. I'm not sure how common this is in other fields, but this is useful in neural decoding: e.g., "how long after the stimulus presentation can we decode it from the neural activity?"

I ended up adapting Selector slightly for my own purposes, adding a select_func parameter:

from sklearn_xarray.preprocessing import BaseTransformer


class Selector(BaseTransformer):
    """Improved version of the sklearn-xarray `Selector` with an adjustable
    selection criterion based on a coordinate's values.

    Parameters
    ----------
    dim : str
        Dimension along which to select samples.
    coord : str
        Coordinate whose values are passed to `select_func`.
    select_func : callable
        Function that takes single values or 1-D arrays, along with a
        keyword argument `param`, and returns a boolean mask.
    param : object
        Extra parameter passed to `select_func` (tunable in a pipeline).
    """

    def __init__(self, dim='sample', coord=None, select_func=None, param=None):
        if coord is None:
            raise ValueError('coord must be specified.')
        self.dim = dim
        self.coord = coord
        self.select_func = select_func
        self.param = param
        self.groupby = None  # unused here

    def _transform(self, X):
        """Select along `self.dim` where `select_func` returns True.

        Parameters
        ----------
        X : xr.DataArray
        """
        X_coord = X[self.coord]
        mask = self.select_func(X_coord, param=self.param)
        return X.isel({self.dim: mask})

    def _inverse_transform(self, X):
        raise NotImplementedError

    def _inverse_transform(self, X):
        raise NotImplementedError

I then pass a less-than-or-equal function as select_func (with select_window__param set by the cross-validation pipeline):

    def lt_eq(arr, param):
        """Helper function for less-than-or-equal-to"""
        return arr <= param

    return Pipeline(
        [
            (
                'select_window',
                Selector(dim='time_bin', coord='time_bin', select_func=lt_eq),
            ),
...
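The threshold is then exposed as a regular pipeline parameter via the usual step__parameter syntax. An illustrative sketch, reusing the Selector and lt_eq defined above (the estimator is just a placeholder, and the real pipeline has additional steps):

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Placeholder pipeline for illustration only
pipe = Pipeline(
    [
        ('select_window', Selector(dim='time_bin', coord='time_bin', select_func=lt_eq)),
        ('estimator', LogisticRegression()),
    ]
)

# The selection threshold can be set per cross-validation candidate,
# e.g. a grid search could sweep {"select_window__param": [0.25, 0.5, 1.0]}
pipe.set_params(select_window__param=0.5)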

Not sure if this would be useful for general-purpose sklearn-xarray?