giotto-ai / giotto-tda

A high-performance topological machine learning toolbox in Python
https://giotto-ai.github.io/gtda-docs

Proposal: Metaestimators making any fit-transformer act on collections #490

Closed: ulupo closed this issue 3 years ago

ulupo commented 4 years ago

Also, because I need to run PCA on the output of the Takens embedding, I ended up not using the TakensEmbedding transformer after all because it produces 3D output from the collection and PCA wants 2D 😢.

Originally posted by @lewtun in https://github.com/giotto-ai/giotto-tda/pull/458#issuecomment-691297711

I think the issue encountered by @lewtun here is unfortunate, and it shows that we should add utilities allowing users to promote scikit-learn transformers (which act on 2D data) to transformers acting on collections of 2D arrays.
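
For context, the shape mismatch looks like this (a minimal sketch using plain NumPy, not tied to any particular gtda transformer):

import numpy as np
from sklearn.decomposition import PCA

X = np.random.random((10, 100, 3))  # a collection of ten 100-point clouds in 3D
PCA(n_components=2).fit_transform(X)  # raises ValueError: Found array with dim 3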

It would be quite easy to provide this sort of functionality to the user, and I can see at least two ways.

A metaestimator taking a transformer instance as argument

The idea here is to provide a metaestimator class in a style similar to this example in scikit-learn: https://scikit-learn.org/stable/auto_examples/cluster/plot_inductive_clustering.html#sphx-glr-auto-examples-cluster-plot-inductive-clustering-py. This could look as follows:

from joblib import Parallel, delayed

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import clone
from sklearn.utils.metaestimators import if_delegate_has_method

from gtda.utils import check_collection

class ForEachInput(BaseEstimator, TransformerMixin):
    """Meta-transformer applying independent clones of `transformer` to each
    entry in a collection of 2D arrays."""

    def __init__(self, transformer, outer_n_jobs=None, outer_prefer=None):
        self.transformer = transformer
        self.outer_n_jobs = outer_n_jobs
        self.outer_prefer = outer_prefer

    def fit(self, X, y=None):
        # Only validate the collection here; fresh clones of `transformer`
        # are fitted on each entry at fit_transform time.
        check_collection(X)

        self._is_fitted = True
        return self

    @if_delegate_has_method(delegate="transformer")
    def fit_transform(self, X, y=None):
        Xt = check_collection(X)

        # Fit-transform an independent clone of `transformer` on each entry,
        # parallelizing across entries with joblib.
        Xt = Parallel(n_jobs=self.outer_n_jobs, prefer=self.outer_prefer)(
            delayed(clone(self.transformer).fit_transform)(x) for x in Xt
            )
        return Xt
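
Usage on the motivating example might then look as follows (a sketch, with takens_output a hypothetical name for the 3D collection produced by TakensEmbedding):

from sklearn.decomposition import PCA

pca_for_collection = ForEachInput(PCA(n_components=2), outer_n_jobs=-1)
Xt = pca_for_collection.fit_transform(takens_output)  # one 2D array per entry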

A factory for transformers, taking a scikit-learn transformer class as argument

Here, we provide the user instead with a factory for producing "collection-wise" versions of arbitrary transformers. The user would e.g. set PCACollection = apply_transformer_to_collection(PCA), and then instantiate it with the same parameters as PCA, e.g. pca_for_collection = PCACollection(n_components=3). The other points are as before. Here, more work had to be done to fix the __init__ signature of the wrapper class at runtime to make it the same as PCA's.

import inspect

from joblib import Parallel, delayed

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import clone

from gtda.utils import check_collection

def apply_transformer_to_collection(estimator_cls):
    # Recover the original __init__ (scikit-learn stores it under
    # `deprecated_original` when an estimator is deprecated) and collect its
    # explicit parameters, dropping `self` and any **kwargs catch-all.
    estimator_init = getattr(estimator_cls.__init__, 'deprecated_original', estimator_cls.__init__)
    estimator_init_signature = inspect.signature(estimator_init)
    default_estimator_parameters = [
        p for p in estimator_init_signature.parameters.values()
        if p.name != "self" and p.kind != p.VAR_KEYWORD
        ]

    class ForEachInput(BaseEstimator, TransformerMixin):
        def __init__(self, outer_n_jobs=None, outer_prefer=None, **estimator_params):
            self.outer_n_jobs = outer_n_jobs
            self.outer_prefer = outer_prefer
            # Fill in defaults for parameters not passed explicitly, then let
            # the wrapped class's __init__ set them as attributes on self.
            estimator_params = {
                p.name: (p.default if p.name not in estimator_params else estimator_params[p.name])
                for p in default_estimator_parameters
                }
            estimator_cls.__init__(self, **estimator_params)

        # Replace the **estimator_params catch-all in the advertised signature
        # with the wrapped class's own parameters, so that get_params/set_params
        # (and hence cloning and grid searches) behave as for the wrapped class.
        __init__.__signature__ = inspect.Signature(
            [p for p in inspect.signature(__init__).parameters.values()][:-1]
            + default_estimator_parameters
            )

        def fit(self, X, y=None):
            # Instantiate the wrapped transformer from the current parameters;
            # fresh clones of it are fitted per entry at fit_transform time.
            estimator_params = {p.name: getattr(self, p.name)
                                for p in default_estimator_parameters}
            self.estimator_ = estimator_cls(**estimator_params)

            return self

        def fit_transform(self, X, y=None):
            Xt = check_collection(X)

            self.fit(X, y=y)

            # Fit-transform an independent clone on each entry, parallelizing
            # across entries with joblib.
            Xt = Parallel(n_jobs=self.outer_n_jobs, prefer=self.outer_prefer)(
                delayed(clone(self.estimator_).fit_transform)(x) for x in Xt
                )
            return Xt

    ForEachInput.__name__ = "ForEachInput" + estimator_cls.__name__

    return ForEachInput
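
For illustration, usage would then look as follows (with X any collection of 2D arrays accepted by check_collection):

from sklearn.decomposition import PCA

PCACollection = apply_transformer_to_collection(PCA)
pca_for_collection = PCACollection(n_components=3, outer_n_jobs=-1)
Xt = pca_for_collection.fit_transform(X)  # one 2D PCA output per entry in X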

Main differences

With the first version, we have an easier-to-maintain, less bug-prone and more scikit-learn-like solution. However, the second solution has a more functional flavour which could appeal to some: having a transformer factory makes for slightly clearer thinking and code layout if one is already set on which transformer should be used.

Users who want to e.g. grid-search across transformer classes, and not just across parameters of a fixed class, will find the first solution easier to work with, as shown in the sketch below. But if they are set on which class to use, the grid will look a little more cumbersome, as deep parameters must be accessed by prepending the parameter name each time, as in transformer__n_components.
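
To make this concrete, here is a minimal sketch with the ForEachInput metaestimator from the first solution (the set_params calls mirror what the entries of a GridSearchCV parameter grid would look like):

from sklearn.decomposition import PCA, KernelPCA

fe = ForEachInput(PCA())
# The wrapped transformer is itself a hyperparameter, so classes are easy to swap ...
fe.set_params(transformer=KernelPCA())
# ... while tuning a fixed class requires nested parameter names:
fe.set_params(transformer__n_components=3)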

See also #108.

wreise commented 4 years ago

Yes :heart_eyes:! (I would vote for the solution based on the factory pattern.)

ulupo commented 4 years ago

@wreise: @gtauzin and I had a chat about this and agreed that, given that some aspects of the second solution are difficult to test comprehensively in a scikit-learn context (the dynamic signature and class-name changes in particular), and given also that the first solution allows the estimator type to be treated as a hyperparameter (good for some users), we'll push the first version for the time being, but look for a way of integrating something like the second one too in the future. I'll make a PR.

ulupo commented 3 years ago

#495 closes this issue.