scikit-learn-contrib / lightning

Large-scale linear classification, regression and ranking in Python
https://contrib.scikit-learn.org/lightning/

CV with 1SE? #84

Open levithatcher opened 8 years ago

levithatcher commented 8 years ago

Love the fact that this package offers group lasso! You're probably aware, but with the lasso one often combines cross-validation with the one-standard-error (1SE) rule: among the candidate models, choose the one with the fewest nonzero coefficients whose cross-validated error is within one standard error of the lowest cross-validated error.

Is there an example anywhere with your group lasso combined with cross-validation or any thoughts on implementing the 1SE functionality? Thanks again for this wonderful resource!
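For a plain lasso, the 1SE rule can be sketched with scikit-learn's `LassoCV`, used here only as a stand-in (it is not lightning's group lasso). The sketch assumes `mse_path_` holds the per-fold validation MSE for each entry of `alphas_` and takes the standard error as the fold standard deviation divided by √(number of folds). `one_se_alpha` is a hypothetical helper name, not part of any library.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso, LassoCV


def one_se_alpha(lasso_cv):
    """Hypothetical helper: return the largest alpha whose mean CV error is
    within one standard error of the minimum mean CV error (the 1SE rule)."""
    mse = lasso_cv.mse_path_                  # shape (n_alphas, n_folds)
    n_folds = mse.shape[1]
    mean_mse = mse.mean(axis=1)               # mean validation MSE per alpha
    se_mse = mse.std(axis=1) / np.sqrt(n_folds)  # standard error per alpha

    best = np.argmin(mean_mse)
    threshold = mean_mse[best] + se_mse[best]
    # Larger alpha means stronger L1 penalty (usually a sparser model),
    # so take the largest alpha whose mean error is under the threshold.
    return lasso_cv.alphas_[mean_mse <= threshold].max()


X, y = make_regression(n_samples=200, n_features=30, n_informative=5,
                       noise=10.0, random_state=0)
cv_model = LassoCV(cv=5, random_state=0).fit(X, y)
alpha_1se = one_se_alpha(cv_model)

# Refit at the 1SE alpha and compare sparsity with the minimum-error alpha.
model_1se = Lasso(alpha=alpha_1se).fit(X, y)
print(cv_model.alpha_, np.count_nonzero(cv_model.coef_))
print(alpha_1se, np.count_nonzero(model_1se.coef_))
```

For an estimator without a built-in CV path, such as a group-lasso model, the same rule would apply to a matrix of per-fold errors computed by hand over a grid of penalties.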

fabianp commented 8 years ago

Haven't seen such an example, and didn't know about the technique. Feel free to submit an example if you think it could be useful.

levithatcher commented 8 years ago

Thanks, Fabian! Appreciate all this great work!

bagibence commented 5 years ago

I'm not sure if this helps with the original problem, but someone might find it useful in the future. You can pass a callable as the refit parameter of GridSearchCV. It receives the dictionary stored in the cv_results_ attribute of the GridSearchCV object and must return the index of the best parameter setting according to whatever criterion you define in the function. I used this to implement the one-standard-error rule for ridge regression.

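import numpy as np
import pandas as pd

def best_alpha_index(results):
    # Number of CV folds: count only the per-fold test-score keys, so that
    # 'split{k}_train_score' keys (return_train_score=True) are not counted too.
    K = len([k for k in results if k.startswith('split') and k.endswith('_test_score')])
    # Assumes the grid searches a Pipeline step named 'ridge'
    alpha_range = results['param_ridge__alpha'].data

    mean_per_alpha = pd.Series(results['mean_test_score'], index=alpha_range)
    std_per_alpha  = pd.Series(results['std_test_score'], index=alpha_range)
    sem_per_alpha  = std_per_alpha / np.sqrt(K)

    max_score  = mean_per_alpha.max()
    sem        = sem_per_alpha[mean_per_alpha.idxmax()]
    # Largest (most regularizing) alpha whose mean score is within 1 SE of the best
    best_alpha = mean_per_alpha[mean_per_alpha >= max_score - sem].index.max()

    # refit must return a position in cv_results_, not the alpha value itself
    best_alpha_index = int(np.flatnonzero(alpha_range == best_alpha)[0])

    return best_alpha_index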

I then passed refit=best_alpha_index when constructing the GridSearchCV object.
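A sketch of the whole workflow might look like the following. It uses a pandas-free variant of the refit function wrapped in a hypothetical factory, `one_se_refit`, and assumes a grid over a single numeric hyperparameter where larger values mean a simpler model (true for `Ridge`'s `alpha`), a scoring where higher is better, and a pipeline step named `'ridge'`.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def one_se_refit(param_name):
    """Hypothetical factory: build a refit callable applying the 1SE rule to a
    single numeric hyperparameter where larger values mean a simpler model."""
    def pick(results):
        # Count folds from the per-fold test-score keys only.
        n_splits = sum(1 for k in results
                       if k.startswith('split') and k.endswith('_test_score'))
        mean = np.asarray(results['mean_test_score'])
        sem = np.asarray(results['std_test_score']) / np.sqrt(n_splits)
        params = np.asarray(np.ma.getdata(results['param_' + param_name]),
                            dtype=float)

        best = np.argmax(mean)
        # Positions whose mean score is within one SE of the best score.
        within = np.flatnonzero(mean >= mean[best] - sem[best])
        # Among those, return the position with the largest hyperparameter.
        return int(within[np.argmax(params[within])])
    return pick


X, y = make_regression(n_samples=150, n_features=20, noise=15.0, random_state=0)
pipe = Pipeline([('scale', StandardScaler()), ('ridge', Ridge())])
search = GridSearchCV(
    pipe,
    param_grid={'ridge__alpha': np.logspace(-3, 3, 13)},
    scoring='neg_mean_squared_error',  # higher is better, as the rule assumes
    cv=5,
    refit=one_se_refit('ridge__alpha'),
)
search.fit(X, y)
print(search.best_params_)
```

With a callable refit, GridSearchCV sets best_index_ and best_params_ from the returned index, but best_score_ is not available.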