dswah / pyGAM

[HELP REQUESTED] Generalized Additive Models in Python
https://pygam.readthedocs.io
Apache License 2.0
857 stars, 157 forks

Fix pyGAM to work with SkLearn GridSearchCV #267

Open · judahrand opened this pull request 4 years ago

judahrand commented 4 years ago

This pull request aims to resolve #247.
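For context, `GridSearchCV` builds a fresh copy of the estimator with `sklearn.base.clone` for each candidate and fold, then applies the candidate via `set_params`. In current scikit-learn releases, `clone` rebuilds the object from `get_params()` and raises `RuntimeError` if the constructor did not store a parameter unchanged (checked by identity). The diff of this PR is not reproduced here, so treating this as the cause of #247 is an assumption. The sketch below uses two hypothetical toy estimators (`ModifiesParams`, `StoresParams`) and a hypothetical helper `clones_cleanly` only to illustrate that contract:

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class ModifiesParams(BaseEstimator):
    """Hypothetical estimator that converts its argument in __init__."""

    def __init__(self, lam=0.6):
        # np.array() always returns a new object, so the stored attribute is
        # not the object clone() passed in, and clone() rejects the estimator.
        self.lam = np.array(lam, dtype=float)


class StoresParams(BaseEstimator):
    """Hypothetical estimator following the scikit-learn __init__ convention."""

    def __init__(self, lam=0.6):
        # Store constructor arguments untouched.
        self.lam = lam

    def fit(self, X, y=None):
        # Validate/convert at fit time instead, into a trailing-underscore attribute.
        self.lam_ = np.atleast_1d(np.asarray(self.lam, dtype=float))
        return self


def clones_cleanly(estimator):
    """Return True if sklearn.base.clone() accepts `estimator`."""
    try:
        clone(estimator)
    except RuntimeError:
        return False
    return True


if __name__ == "__main__":
    print(clones_cleanly(StoresParams(lam=[1.0, 2.0])))    # True
    print(clones_cleanly(ModifiesParams(lam=[1.0, 2.0])))  # False
```

The usual remedy is to keep `__init__` as plain attribute assignment and defer any conversion or validation to `fit`.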

Minimal example:

```python
import numpy as np
import pandas as pd
# The original example used load_boston, which was removed in scikit-learn 1.2;
# load_diabetes is a bundled regression dataset used here instead.
from sklearn.datasets import load_diabetes
from sklearn.model_selection import GridSearchCV

import pygam


def gam(x, y):
    # 10 candidate penalty vectors, one smoothing weight per feature, each in [1, e)
    lams = np.exp(np.random.rand(10, x.shape[1]))
    linear_gam = pygam.LinearGAM(n_splines=10, max_iter=1000)
    parameters = {'lam': list(lams)}
    # `iid=False` from the original was dropped: GridSearchCV removed `iid` in 0.24
    gam_cv = GridSearchCV(linear_gam, parameters, cv=5, return_train_score=True,
                          refit=True, scoring='neg_mean_squared_error')
    gam_cv.fit(x, y)
    # Negated MSE: higher is better, so sort in descending order
    cv_results_df = (pd.DataFrame(gam_cv.cv_results_)
                     .sort_values(by='mean_test_score', ascending=False))
    return gam_cv, cv_results_df


if __name__ == "__main__":
    X, y = load_diabetes(return_X_y=True)
    gam_cv, cv_results_df = gam(X, y)
    print(gam_cv)
    print(cv_results_df.head())
```
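In the example, each element of `parameters['lam']` is an entire per-feature penalty vector, so the grid has one candidate per row of `lams` rather than a Cartesian product over features. A small sketch using scikit-learn's `ParameterGrid` to count candidates and fits; `n_features = 10` and the fixed seed are arbitrary assumptions, not values from the PR:

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

rng = np.random.default_rng(0)  # fixed seed, assumed for reproducibility
n_features = 10                 # arbitrary; the example uses x.shape[1]
n_candidates = 10
cv = 5

# Same construction as the example: exp of uniform [0, 1) draws,
# so every penalty lies in [1, e).
lams = np.exp(rng.random((n_candidates, n_features)))
grid = list(ParameterGrid({"lam": list(lams)}))

# GridSearchCV fits every candidate on every fold, plus one final refit
# on the full data because refit=True.
n_fits = len(grid) * cv + 1

print(len(grid), grid[0]["lam"].shape, n_fits)  # 10 (10,) 51
```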
codecov[bot] commented 4 years ago

Codecov Report

Merging #267 into master will increase coverage by 0.13%. The diff coverage is 100.00%.


```diff
@@            Coverage Diff             @@
##           master     #267      +/-   ##
==========================================
+ Coverage   95.05%   95.19%   +0.13%     
==========================================
  Files          22       22              
  Lines        3178     3184       +6     
==========================================
+ Hits         3021     3031      +10     
+ Misses        157      153       -4     
```
| Impacted Files | Coverage Δ |
|---|---|
| pygam/pygam.py | 94.83% <100.00%> (+0.03%) ⬆️ |
| pygam/utils.py | 87.73% <0.00%> (+0.30%) ⬆️ |
| pygam/tests/test_GAM_methods.py | 100.00% <0.00%> (+0.36%) ⬆️ |
| pygam/tests/test_utils.py | 96.50% <0.00%> (+1.39%) ⬆️ |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 2942579...2064bea.
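The totals in the coverage diff can be cross-checked by hand, since coverage is hits divided by lines. A minimal sketch, assuming the report truncates (rather than rounds) to two decimals and computes the delta from unrounded values; that assumption is inferred only because it reproduces the printed 95.05%, 95.19% and +0.13%, and is not stated in the report itself:

```python
import math
from fractions import Fraction


def coverage(hits, lines):
    """Exact coverage as a percentage: hits / lines * 100."""
    return Fraction(hits, lines) * 100


def truncate_2dp(value):
    """Cut (not round) a non-negative percentage to two decimal places."""
    return math.floor(value * 100) / 100


base = coverage(3021, 3178)  # master
head = coverage(3031, 3184)  # #267

print(truncate_2dp(base))         # 95.05
print(truncate_2dp(head))         # 95.19
print(truncate_2dp(head - base))  # 0.13
print(round(float(base), 2))      # 95.06, so plain rounding would not match
```

Using `Fraction` keeps the ratios exact, so the truncation is not affected by floating-point error.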