lemma-osu / sknnr-spatial

https://sknnr-spatial.readthedocs.io

Support spatial `transform` in estimators that inherit from `TransformerMixin` #16

Open grovduck opened 4 months ago

grovduck commented 4 months ago

For some use cases, it will be valuable to return the spatial representation of an estimator's transform method where that estimator inherits from sklearn's TransformerMixin (for example, sklearn's PCA or sknnr's CCATransformer).

As of #12, sknnr-spatial supports predict and kneighbors in a functional context (soon to be implemented as an estimator wrapper in #13). Supporting transform will likely be an extension of this logic. Based on some initial crude experimentation, the following code implements transform using a functional API. This code is not fully tested, but I'm introducing it here to keep a record of what was done.

src/sknnr-spatial/image/_base.py

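from __future__ import annotations

from functools import singledispatch

import xarray as xr
from numpy.typing import NDArray
from sklearn.base import BaseEstimator

@singledispatch
def transform(
    X_image: NDArray | xr.DataArray | xr.Dataset,
    *,
    estimator: BaseEstimator,
    nodata_vals=None,
) -> None:
    # Fallback for image types with no registered implementation
    msg = f"transform is not implemented for type `{X_image.__class__.__name__}`."
    raise NotImplementedError(msg)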

src/sknnr-spatial/image/ndarray.py

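import numpy as np
from numpy.typing import NDArray
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

from ._base import transform

# NDArrayPreprocessor is a project-internal helper (import omitted here)


@transform.register(np.ndarray)
def _transform_from_ndarray(
    X_image: NDArray,
    *,
    estimator: BaseEstimator,
    nodata_vals=None,
) -> NDArray:
    check_is_fitted(estimator)
    preprocessor = NDArrayPreprocessor(X_image, nodata_vals=nodata_vals)

    # TODO: Deal with sklearn warning about missing feature names
    y_transform_flat = estimator.transform(preprocessor.flat)

    # Restore the image shape and mask out nodata pixels
    return preprocessor.unflatten(y_transform_flat, apply_mask=True)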
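`NDArrayPreprocessor` is internal to sknnr-spatial and isn't shown here. As a rough, self-contained illustration of the flatten, transform, and unflatten round trip that `_transform_from_ndarray` relies on, here is a minimal sketch. It assumes a band-last `(rows, cols, bands)` array, a single scalar nodata value, and NaN as the fill for masked pixels; `flatten_image`, `unflatten_image`, and `transform_image` are hypothetical helpers, not part of the package.

```python
import numpy as np
from sklearn.decomposition import PCA


def flatten_image(image, nodata=None):
    """Reshape (rows, cols, bands) to (pixels, bands) and flag nodata pixels."""
    rows, cols, bands = image.shape
    flat = image.reshape(rows * cols, bands).astype(float)
    if nodata is None:
        mask = np.zeros(rows * cols, dtype=bool)
    else:
        # A pixel is masked if any of its bands equals the nodata value
        mask = np.any(flat == nodata, axis=1)
    return flat, mask


def unflatten_image(flat_out, mask, shape, fill=np.nan):
    """Reshape (pixels, n_out) to (rows, cols, n_out) and fill masked pixels."""
    out = flat_out.astype(float)  # astype copies, so the input is untouched
    out[mask] = fill
    return out.reshape(*shape, out.shape[1])


def transform_image(image, estimator, nodata=None):
    """Apply a fitted transformer pixel-wise to a band-last image."""
    flat, mask = flatten_image(image, nodata=nodata)
    flat_out = estimator.transform(flat)
    return unflatten_image(flat_out, mask, image.shape[:2])


# Fit PCA on "sample" pixels, then apply it to a 4 x 5 image with 3 bands
rng = np.random.default_rng(0)
pca = PCA(n_components=2).fit(rng.normal(size=(50, 3)))

image = rng.normal(size=(4, 5, 3))
image[0, 0, :] = -9999  # mark one pixel as nodata
result = transform_image(image, pca, nodata=-9999)
print(result.shape)                  # (4, 5, 2)
print(np.isnan(result[0, 0]).all())  # True
```

In this sketch the nodata pixel still passes through `transform` and is only masked afterwards; dropping those rows first would save work but complicates the reshape back to image space.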

src/sknnr-spatial/image/_dask_backed.py

import dask.array as da
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

# DaskBackedType, DataArrayPreprocessor, and DatasetPreprocessor are
# project-internal (imports omitted here)


def transform_from_dask_backed_array(
    X_image: DaskBackedType,
    *,
    estimator: BaseEstimator,
    y=None,
    preprocessor_cls: type[DataArrayPreprocessor] | type[DatasetPreprocessor],
    nodata_vals=None,
) -> DaskBackedType:
    """Generic transform wrapper for Dask-backed arrays."""
    check_is_fitted(estimator)
    preprocessor = preprocessor_cls(X_image, nodata_vals=nodata_vals)

    # HACK: Use get_feature_names_out() to infer the number of output bands.
    # This assumes one name per output column, which may not hold for every transformer.
    target_names = estimator.get_feature_names_out()
    n_targets = len(target_names)

    # Map transform over the band dimension; Dask needs the output size up front
    y_transform = da.apply_gufunc(
        estimator.transform,
        "(x)->(y)",
        preprocessor.flat,
        axis=preprocessor.flat_band_dim,
        output_dtypes=[float],
        output_sizes={"y": n_targets},
        allow_rechunk=True,
    )

    return preprocessor.unflatten(y_transform, var_names=target_names)

src/sknnr-spatial/image/dataarray.py

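import xarray as xr
from sklearn.base import BaseEstimator

from ._base import transform
from ._dask_backed import transform_from_dask_backed_array

# DataArrayPreprocessor is a project-internal helper (import omitted here)


@transform.register(xr.DataArray)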
def _transform_from_dataarray(
    X_image: xr.DataArray, *, estimator: BaseEstimator, y=None, nodata_vals=None
) -> xr.DataArray:
    return transform_from_dask_backed_array(
        X_image,
        estimator=estimator,
        y=y,
        nodata_vals=nodata_vals,
        preprocessor_cls=DataArrayPreprocessor,
    )

src/sknnr-spatial/image/dataset.py

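import xarray as xr
from sklearn.base import BaseEstimator

from ._base import transform
from ._dask_backed import transform_from_dask_backed_array

# DatasetPreprocessor is a project-internal helper (import omitted here)


@transform.register(xr.Dataset)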
def _transform_from_dataset(
    X_image: xr.Dataset, *, estimator: BaseEstimator, y=None, nodata_vals=None
) -> xr.Dataset:
    return transform_from_dask_backed_array(
        X_image,
        estimator=estimator,
        y=y,
        nodata_vals=nodata_vals,
        preprocessor_cls=DatasetPreprocessor,
    )
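The registration pattern above relies on `functools.singledispatch`, which picks an implementation from the type of the first positional argument (`X_image`) and falls back to the undecorated base function otherwise. A stripped-down sketch of that mechanism, using hypothetical names (`transform_image`, `DoubleBands`) and only NumPy, might look like this:

```python
from functools import singledispatch

import numpy as np


@singledispatch
def transform_image(X_image, *, estimator):
    # Fallback: no implementation registered for this type
    msg = f"transform is not implemented for type `{X_image.__class__.__name__}`."
    raise NotImplementedError(msg)


@transform_image.register(np.ndarray)
def _(X_image, *, estimator):
    # Band-last (rows, cols, bands) layout is an assumption of this sketch
    rows, cols, bands = X_image.shape
    flat = X_image.reshape(-1, bands)  # (pixels, bands)
    out = estimator.transform(flat)    # (pixels, n_outputs)
    return out.reshape(rows, cols, -1)


class DoubleBands:
    """Hypothetical stand-in for a fitted transformer that duplicates each band."""

    def transform(self, X):
        return np.hstack([X, X])


img = np.arange(12, dtype=float).reshape(2, 2, 3)
print(transform_image(img, estimator=DoubleBands()).shape)  # (2, 2, 6)

try:
    transform_image([1, 2, 3], estimator=DoubleBands())
except NotImplementedError as err:
    error_message = str(err)
    print(error_message)  # transform is not implemented for type `list`.
```

Because `estimator` is keyword-only, it never takes part in dispatch, so the same fitted transformer can be passed to the ndarray, DataArray, and Dataset implementations alike.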
grovduck commented 4 months ago

@aazuspan, please feel free to modify the issue description. I just wanted to capture what I had done to make synthetic kNN work with the transform step. I'll stash these changes locally so I can test the latest changes you've made in #12.

aazuspan commented 4 months ago

Looks great @grovduck, thanks!

I just did a quick check, and it looks like nearly all transformers implement get_feature_names_out. Maybe we just throw a NotImplementedError for any that don't? For reference:

from sklearn.utils.discovery import all_estimators

# List transformer classes that do not implement get_feature_names_out
for name, trans in all_estimators("transformer"):
    if not hasattr(trans, "get_feature_names_out"):
        print(name)

# Output:
# FeatureHasher
# HashingVectorizer
# LabelBinarizer
# LabelEncoder
# MultiLabelBinarizer
# PatchExtractor

At a glance, none of these seem relevant to applying in 2D space.

grovduck commented 4 months ago

> Maybe we just throw a NotImplementedError for any that don't?

That definitely seems like a reasonable approach to me; those transformers don't strike me as ones we'd need. The flip side would be whether those estimators/transformers that do implement get_feature_names_out make sense to transform into 2D space? Looks like there are 84 of these (not counting sknnr transformers).

aazuspan commented 4 months ago

> The flip side would be whether those estimators/transformers that do implement get_feature_names_out make sense to transform into 2D space? Looks like there are 84 of these (not counting sknnr transformers).

Yeah, most of those 84 probably would never be used, but I suppose if they follow a consistent protocol and we can implement them all in one go, there's no harm (maybe with a disclaimer that we only explicitly test/support a subset of them?). One thing we'll need to watch out for is any estimator/transformer that modifies the spatial shape, i.e., one that returns more or fewer samples than it receives, since that would break the Dask side. I'm not aware of anything like that, but I've probably never touched 90% of the functionality in sklearn, so I wouldn't be shocked.
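The invariant that matters for the Dask `(x)->(y)` signature is that `transform` returns one row per input row; the number of samples the estimator was fit with doesn't enter into it. A crude way to probe for this, sketched with a hypothetical `preserves_n_samples` helper, is to transform a small random array and compare row counts. The random-normal probe is an assumption that won't suit every transformer (some require non-negative or integer input).

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def preserves_n_samples(estimator, n_features, n_probe=7, seed=0):
    """Return True if estimator.transform yields one output row per input row.

    Hypothetical helper; the probe size deliberately differs from the
    training size to show that only rows-in vs. rows-out matters.
    """
    probe = np.random.default_rng(seed).normal(size=(n_probe, n_features))
    return estimator.transform(probe).shape[0] == n_probe


rng = np.random.default_rng(1)
X_train = rng.normal(size=(40, 3))  # fit on 40 samples, probe with 7
pca = PCA(n_components=2).fit(X_train)
scaler = StandardScaler().fit(X_train)

print(preserves_n_samples(pca, n_features=3))     # True
print(preserves_n_samples(scaler, n_features=3))  # True
```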

I'm pretty excited for this feature: being able to run PCA or StandardScaler on images seamlessly will be great!

grovduck commented 4 months ago

> One thing we'll need to watch out for is any estimator/transformer that modifies the spatial shape, i.e., one that returns more or fewer samples than it receives, since that would break the Dask side. I'm not aware of anything like that, but I've probably never touched 90% of the functionality in sklearn, so I wouldn't be shocked.

Good call. I recognized only a handful of those transformers as well, so I don't know whether any of them would modify the shape, but it's definitely something to watch out for.

> I'm pretty excited for this feature: being able to run PCA or StandardScaler on images seamlessly will be great!

Absolutely!

Just to be clear, your plan is to tackle #13 before handling this one, correct? Please let me know if there are "side jobs" on either of these issues that you'd like my help with (other than reviews).

aazuspan commented 4 months ago

> Just to be clear, your plan is to tackle https://github.com/lemma-osu/sknnr-spatial/issues/13 before handling this one, correct? Please let me know if there are "side jobs" on either of these issues that you'd like my help with (other than reviews).

Yes, I made some changes in #18 that should reduce code duplication when adding new methods like this (although that means your implementation will need to be refactored a little bit to match, sorry!). If you want to tackle this issue once #18 is merged, that would be great! You've got a better idea of the real-world use cases for this, and you've already figured out the tricky part of getting the output shape.