ganguli-lab / nems

Neural encoding models
MIT License

sklearn-compatible API? #3

[Open] neuromusic opened this issue 6 years ago

neuromusic commented 6 years ago

@nirum, would you be open to a PR refactoring the models in models.py to be scikit-learn compatible? Or to adding scikit-learn-compatible versions in a separate module?

nirum commented 6 years ago

Hey! Yes, I totally would. I don't think it would be too much work, either. There are already fit and predict methods on the model class, but I suspect they do not follow the sklearn API.

One issue: the features to the model are often multi-dimensional (e.g. spatiotemporal stimuli), so I would like to keep the ability to call model.fit(X) where X has shape [num_timesteps, num_pixels, num_pixels], for example. IIRC, sklearn expects the data matrix to always have the form [num_examples, num_features]. Am I remembering correctly?
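For reference, a small sketch of the convention in question, assuming a recent scikit-learn: `sklearn.utils.check_array` raises a `ValueError` for inputs with more than two dimensions unless `allow_nd=True` is passed, and flattening the trailing axes recovers the usual [num_examples, num_features] layout. The 100 x 8 x 8 array is made up for illustration and is not nems data.

```python
import numpy as np
from sklearn.utils import check_array

# Made-up spatiotemporal stimulus: [num_timesteps, num_pixels, num_pixels]
X = np.random.randn(100, 8, 8)

# By default check_array only accepts (at most) 2-D input and raises ValueError otherwise
try:
    check_array(X)
except ValueError as err:
    print("rejected:", err)

# allow_nd=True validates the array but keeps its original shape
X_nd = check_array(X, allow_nd=True)

# Flattening the trailing axes gives sklearn's usual [num_examples, num_features] layout
X_2d = X.reshape(X.shape[0], -1)
print(X_nd.shape, X_2d.shape)
```

So an estimator can still accept 3-D stimuli as long as it validates with `allow_nd=True` and reshapes internally before doing any 2-D linear algebra.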

neuromusic commented 6 years ago

Multidimensional features aren't an issue, per se. For example, there are

Getting the fit/predict functions to conform is one piece, but there are internals that need to be consistent to be fully compatible (e.g. to take advantage of sklearn's caching in hyperparameter searches).
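Much of that internal consistency presumably comes down to one scikit-learn convention: `__init__` only stores its keyword arguments, unmodified, as attributes with the same names, and all data-dependent work happens in `fit`. That is what lets `BaseEstimator.get_params`, `set_params`, and `sklearn.base.clone` (which hyperparameter searches such as `GridSearchCV` use to build a fresh estimator per candidate) work. The `ToyModel` class and its `l2` and `n_iter` parameters below are hypothetical and not part of nems.

```python
from sklearn.base import BaseEstimator, clone


class ToyModel(BaseEstimator):
    # Hypothetical estimator: __init__ only stores hyperparameters, unchanged
    def __init__(self, l2=1.0, n_iter=100):
        self.l2 = l2
        self.n_iter = n_iter
        # No validation and no data-dependent attributes here; those belong in fit()


model = ToyModel(l2=0.1)
print(model.get_params())  # {'l2': 0.1, 'n_iter': 100}

# set_params updates hyperparameters in place (and returns the estimator)
model.set_params(n_iter=500)

# clone builds a new, unfitted estimator from get_params()
fresh = clone(model)
print(fresh.get_params())  # {'l2': 0.1, 'n_iter': 500}
```

If `__init__` instead transformed or validated its arguments, `clone` could not faithfully reconstruct the estimator from `get_params()`, which is why that work is usually deferred to `fit`.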

I suspect the scope of work is approx:

- inherit from sklearn BaseEstimator & RegressorMixin
- get .fit() & .predict() methods to be consistent w/ the sklearn API
- refactor __init__ to allow caching
- add a test of API compliance using sklearn's check_estimator

More details are here: http://scikit-learn.org/stable/developers/contributing.html

Especially here: http://scikit-learn.org/stable/developers/contributing.html#rolling-your-own-estimator

nirum commented 6 years ago

Great, thanks for collecting these references! Happy to accept a PR. I can look into it myself, but realistically I won’t have time before early October.

(Replied by email to Justin Kiggins' comment of Sep 21, 2018, 1:27 PM.)
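The scope of work listed above might translate into something like the following sketch. It is not the nems model: the class name `RidgeEncodingModel`, the closed-form ridge fit (no intercept), and the `l2` parameter are hypothetical stand-ins. It only illustrates the structure: inheriting from `RegressorMixin` and `BaseEstimator` (mixin first, as current scikit-learn recommends), an `__init__` that stores hyperparameters unchanged, `fit` returning `self`, and multi-dimensional stimuli being flattened internally. It has not been run through `sklearn.utils.estimator_checks.check_estimator`, the last item on the list, which would likely flag further details.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class RidgeEncodingModel(RegressorMixin, BaseEstimator):
    """Hypothetical linear encoding model; a placeholder for the nems models."""

    def __init__(self, l2=1.0):
        # Store hyperparameters unchanged so get_params/clone work
        self.l2 = l2

    def _flatten(self, X):
        # Accept [num_timesteps, ...] stimuli and flatten trailing axes into features
        X = check_array(X, allow_nd=True)
        return X.reshape(X.shape[0], -1)

    def fit(self, X, y):
        X, y = check_X_y(self._flatten(X), y, y_numeric=True)
        n_features = X.shape[1]
        # Closed-form ridge solution: (X^T X + l2 * I)^{-1} X^T y
        A = X.T @ X + self.l2 * np.eye(n_features)
        self.coef_ = np.linalg.solve(A, X.T @ y)
        self.n_features_in_ = n_features
        return self

    def predict(self, X):
        check_is_fitted(self, "coef_")
        return self._flatten(X) @ self.coef_


# Usage on made-up, noiseless data: [num_timesteps, num_pixels, num_pixels] stimulus
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 4, 4))
w = rng.standard_normal(16)
y = X.reshape(200, -1) @ w

model = RidgeEncodingModel(l2=1e-6).fit(X, y)
print(model.score(X, y))  # R^2 from RegressorMixin; near 1 on this noiseless toy data

# A hyperparameter search over l2 works directly on the 3-D stimulus array
search = GridSearchCV(RidgeEncodingModel(), {"l2": [1e-3, 1.0, 100.0]}, cv=3).fit(X, y)
print(search.best_params_)
```

Keeping the reshape inside the estimator preserves the `model.fit(X)` call on [num_timesteps, num_pixels, num_pixels] arrays while everything downstream of validation sees sklearn's 2-D layout.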