BalzaniEdoardo opened 1 year ago
+1 to add `sklearn` as a dependency.
Some additional thoughts: `score` for scikit-learn has to return a scalar; this may create problems if we want to regularize the model (basically mandatory with real data) and fit the population jointly. Here is why:
1) the amount of regularization needed for a neuron with shallow tuning will be very different from that of a strongly tuned neuron.
2) cross-validation should learn a regularization hyper-parameter for each neuron. This can be done in parallel by grid-search cross-validation, fitting jointly and returning a score for each neuron instead of a single scalar. For each neuron we can pick the regularizer with the best cv-score. This works because the neurons are assumed independent conditional on the past spike history, so we don't need to grid search over all the combinations of the parameters.
3) the `score` method at point (2) violates the scikit-learn API and is not compatible with the `sklearn.model_selection` module, so we would need to code the cross-validation from scratch.
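To make point (2) concrete, here is a rough sketch of the per-neuron selection step (all names and shapes here are illustrative, not actual nemos API; the scores are random placeholders standing in for real CV results):

```python
import numpy as np

rng = np.random.default_rng(0)

reg_strengths = np.logspace(-3, 1, 5)  # candidate regularizer values
n_neurons = 4

# cv_scores[i, j]: mean CV score of regularizer i for neuron j.
# In practice this would come from fitting the population jointly at each
# regularizer value and scoring each neuron separately.
cv_scores = rng.normal(size=(reg_strengths.size, n_neurons))

# Because neurons are conditionally independent, we can pick the best
# regularizer per neuron instead of searching the full product grid
# (5 fits instead of 5**4 combinations here).
best_idx = cv_scores.argmax(axis=0)
best_reg = reg_strengths[best_idx]  # one regularizer per neuron, shape (4,)
```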
In my experience, calling sklearn `GridSearchCV` one neuron at a time on the GPU was still way more efficient than fitting the data with sklearn `PoissonRegressor` or `scipy.minimize`. The difference is massive: with V1 recordings of about 10 mins, I was able to fit a 5-fold cv Ridge-GLM on the GPU over the whole population (about 70 neurons) in 2 hours; after 2 days, `PoissonRegressor` hadn't finished yet.
Fitting jointly will likely be even faster, but the trade-off is compatibility with sklearn model selection and any pipeline involving a call to `score`.
One way to go could be to implement a `fit_population` and `score_population`, as well as a K-fold grid-search cv that relies on those methods as a fast option, possibly fitting neurons in batches (the full population may still be too large to fit in any GPU), while keeping `fit` and `score` as they are if one wants to rely on scikit-learn pipelines fitting one neuron at a time.
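A sketch of what that split API could look like (hypothetical stubs, not an actual implementation; the zero coefficients just stand in for a real solver):

```python
import numpy as np

class PopulationGLM:
    """Illustrative stub: fit/score keep the sklearn contract (one neuron,
    scalar score), while the *_population variants handle the joint case."""

    def fit(self, X, y):
        # y: (n_samples,) spike counts for one neuron -> sklearn-compatible.
        self.coef_ = np.zeros(X.shape[1])
        return self

    def score(self, X, y):
        # Returns a scalar, as sklearn.model_selection requires.
        return 0.0

    def fit_population(self, X, Y, batch_size=None):
        # Y: (n_samples, n_neurons); batch_size would let us fit neurons in
        # batches when the full population does not fit on the GPU at once.
        self.coef_ = np.zeros((X.shape[1], Y.shape[1]))
        return self

    def score_population(self, X, Y):
        # Returns one score per neuron, shape (n_neurons,) -- this breaks
        # sklearn's scalar-score contract, hence the custom K-fold CV.
        return np.zeros(Y.shape[1])

model = PopulationGLM().fit_population(np.ones((10, 3)), np.ones((10, 5)))
```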
I am less sure we need the `scikit-learn` dependency. `BaseEstimator` is not a generic base class: it has a lot of methods for validating the input and other methods that are scikit-learn specific, and we make no use of them yet. Overall, for the 3 methods that we need, we end up with machinery that we either don't use, or that is not aligned with our code architecture.
Here is a list of methods from the scikit-learn `BaseEstimator` class along with a brief description for each, noting whether we need them in `nemos.Base` [Needed], whether they may be useful [Maybe Useful], or whether we do not want them [Unwanted]:
[Needed] `_get_param_names(cls)`: Retrieves the names of the parameters for the estimator. Inspects the constructor's arguments and raises an error if varargs are used.
[Needed] `get_params(self, deep=True)`: Returns a dictionary of all parameters for this estimator. If `deep` is True, it includes parameters of subobjects that are estimators.
[Needed] `set_params(self, **params)`: Sets the parameters of the estimator. It supports both simple and nested objects, updating parameters for each component of a nested object.
[Maybe Useful] `__sklearn_clone__(self)`: Returns a clone of the estimator instance.
[Maybe Useful] `__repr__(self, N_CHAR_MAX=700)`: Provides a string representation of the estimator, with optional character limit `N_CHAR_MAX`. Uses compact formatting and ellipsis for long sequences.
[Maybe Useful] `__getstate__(self)`: Handles getting the state for pickling. Raises a TypeError if `__slots__` are used. Deep-copies `self.__dict__`.
[Not Useful?] `__setstate__(self, state)`: Sets the state when unpickling. Includes version compatibility checks and warnings; this checks that the unpickled model is compatible with the sklearn version. Probably useless for us.
[Not Useful?] `_more_tags(self)`: Returns additional tags for the estimator. Used for providing custom tags in subclasses; recursively collects the tags of all subclasses.
[Not Useful?] `_get_tags(self)`: Collects tags from the class hierarchy. Aggregates tags from the `_more_tags` methods of base classes.
[Unwanted] `_check_n_features(self, X, reset)`: Checks the number of features in `X` against `n_features_in_`. The `reset` parameter determines set/check behavior.
[Unwanted] `_check_feature_names(self, X, reset)`: Sets or checks `feature_names_in_`, ensuring feature name consistency between `X` and training data.
[Unwanted] `_validate_data(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)`: Validates input data `X` and `y`, setting/checking `n_features_in_` and `feature_names_in_`. Handles different validation scenarios.
Note on Tags in Scikit-learn
In scikit-learn, a tag is a simple way to provide meta-information about an estimator. Tags allow developers to specify certain properties of the estimator, such as whether it requires the target data `y` during fitting, supports multioutput, or works with sparse matrices. Tags are used internally by scikit-learn to make decisions about handling different estimators in a generic manner. They are defined using the `_get_tags()` method and can be overridden in custom estimators to provide specific behavior.
This is an example of what `_get_tags()` returns on a Poisson regressor:

```python
model = lin.PoissonRegressor()
model._get_tags()
Out[6]:
{'array_api_support': False,
 'non_deterministic': False,
 'requires_positive_X': False,
 'requires_positive_y': True,
 'X_types': ['2darray'],
 'poor_score': False,
 'no_validation': False,
 'multioutput': False,
 'allow_nan': False,
 'stateless': False,
 'multilabel': False,
 '_skip_test': False,
 '_xfail_checks': False,
 'multioutput_only': False,
 'binary_only': False,
 'requires_fit': True,
 'preserves_dtype': [numpy.float64],
 'requires_y': True,
 'pairwise': False}
```
We may think to add `sklearn` as a dependency and inherit `sklearn.base.BaseEstimator` directly instead of implementing our own base class. The reasoning behind this is the following: 1) less code to maintain; 2) ensured compatibility with meta-data routing, which is still experimental in scikit-learn, but could be potentially useful and hard to implement/maintain on our own.
If we don't care about routing, implementing and maintaining our own get/set is light work, so I would be fine with maintaining compatibility and avoiding the dependency.