interpretml / interpret

Fit interpretable models. Explain blackbox machine learning.
https://interpret.ml/docs
MIT License

model fitting very slow (colab) #585

Closed · twright8 closed this 3 weeks ago

twright8 commented 3 weeks ago

Hi! Thanks in advance.

I am using Optuna for hyperparameter tuning, but each fit is incredibly slow (this is true outside of Optuna as well, just using standard parameters). One fit often takes 2 hours, and a full cross-validation much longer. I have quite a few variables, but not a huge amount of data:


Snippet of my data:

|  | a | b | c | d | e | f | g | h | i | j | k | l | m | n | o | p | q | r | s | t | country | year |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1 | 97.7 |  | 28.40251 |  | 85.95955 | 80.28806 | 74.81396 | 11.71 | 61.6 | 73.974 | 15.53361 | 59.75228 |  | 2568.342 |  |  | 62.575 |  | 620 | 19 | Afghanistan | 2020 |
| 2 | 97.7 |  | 27.49173 |  | 87.64932 | 74.23411 | 76.81837 | 11.224 | 63.6 | 74.246 | 14.83132 | 57.90825 |  | 2589.41 |  |  | 63.565 |  | 644 | 16 | Afghanistan | 2019 |
| 3 | 93.4 |  | 26.58131 | 0.63602 | 89.09145 | 71.3343 | 76.22302 | 11.206 | 65.8 | 74.505 | 14.20842 | 56.24823 |  | 2436.006 |  |  | 63.081 |  | 663 | 16 | Afghanistan | 2018 |
| 4 | 99.7 | 46.84473 | 84.40004 | 1.00668 | 43.15453 | 863.2817 | 24.82529 | 12.33 | 15 | 13.431 | 9.46475 | 25.14514 | 53.9 | 15165.06 | 5.3 | 13.7 | 75.109 | 93 | 59 | 35 | Brazil | 2018 |
| … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | 2014 |
| 1296 | 32.3 | 31.10169 | 27.43137 |  | 85.8755 | 114.4443 | 24.83175 | 4.774 | 62.3 | 67.496 | 8.132273 | 35.81686 |  | 2618.241 |  |  | 58.846 | 88.69342 | 441 | 21 | Zimbabwe | 2019 |

Is there anything obvious that is leading to this, or is it expected?

imports

import optuna
from optuna.storages import JournalStorage, JournalFileStorage
from sklearn.model_selection import cross_validate, GroupKFold
from sklearn.metrics import make_scorer, mean_squared_error, r2_score
from interpret.glassbox import ExplainableBoostingRegressor

Create storage that saves to a file

storage = JournalStorage(
    JournalFileStorage("optuna_study.log")
)

hyperparameters

def objective(trial):
    params = {
        'max_bins': 1024,
        'smoothing_rounds': trial.suggest_categorical(
            'smoothing_rounds',
            [0, 25, 50, 75, 100, 150, 200, 350, 500, 750, 1000, 1500, 2000, 4000]),
        'max_interaction_bins': 64,
        'interactions': trial.suggest_float('interactions', 0, 1.0),
        'greedy_ratio': trial.suggest_float('greedy_ratio', 0.0, 10.0),
        'cyclic_progress': trial.suggest_float('cyclic_progress', 0.0, 1.0),
        'outer_bags': 14,
        'inner_bags': 20,
        'learning_rate': trial.suggest_float('learning_rate', 0.0025, 0.2, log=True),
        'validation_size': trial.suggest_float('validation_size', 0.05, 0.4),
        'max_leaves': trial.suggest_int('max_leaves', 2, 3),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 2, 10),
        'min_hessian': trial.suggest_float('min_hessian', 0.000001, 0.01, log=True),
        'early_stopping_rounds': 100,
        'random_state': 42,
        'max_rounds': 25000
    }
    print(params)
    model = ExplainableBoostingRegressor(**params, feature_names=feat, n_jobs=1)

    scoring = {
        'r2': 'r2',
        'mse': make_scorer(mean_squared_error)
    }

    print("now doing stuff")
    cv = GroupKFold(n_splits=5)
    cv_results = cross_validate(model, X_transformed, y_transformed, cv=cv,
                                scoring=scoring, groups=X_transformed['country'], n_jobs=1)

    print(f"R2 scores: {cv_results['test_r2']}")
    print(f"MSE scores: {cv_results['test_mse']}")
    print(f"Mean R2: {cv_results['test_r2'].mean()}")
    print(f"Mean MSE: {cv_results['test_mse'].mean()}")

    return cv_results['test_r2'].mean()

create study

if testing:
    try:
        study = optuna.load_study(
            study_name="maternal_mortality_study",
            storage=storage
        )
        print("Loaded existing study")
    except Exception:  # study does not exist yet in this storage
        study = optuna.create_study(
            study_name="maternal_mortality_study",
            storage=storage,
            direction='maximize',
            load_if_exists=True
        )
        print("Created new study")

    # start study
    study.optimize(objective, n_trials=40, show_progress_bar=True, n_jobs=2)
    optuna.visualization.plot_param_importances(study)
    best_hyperparameters = study.best_params

    # Get final fit
    model = ExplainableBoostingRegressor(**best_hyperparameters, feature_names=feat, n_jobs=-1)
    model.fit(X_transformed, y_transformed)
paulbkoch commented 3 weeks ago

Hi @twright8 -- Some of those hyperparameters are going to be fairly time-expensive. As the hyperparameter documentation mentions (https://interpret.ml/docs/hyperparameters.html), inner_bags in particular is very expensive: setting inner_bags to 20 will make fitting take about 20 times longer than the default of 0. If you'd like to reduce your fitting time, I'd also recommend eliminating the smoothing_rounds values above 1000, and perhaps even anything above 500. The rest of your hyperparameters look reasonable to me time-wise.
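
For illustration, a trimmed search space along those lines might look like the sketch below (untested; `suggest_params` is just a hypothetical helper name, everything else is copied from the objective in the original post, with inner_bags left at its default of 0 and the smoothing_rounds candidates capped at 500):

```python
def suggest_params(trial):
    # Same search space as in the original objective, minus the expensive settings:
    # inner_bags is simply omitted (so it stays at the default of 0), and the
    # smoothing_rounds candidates stop at 500.
    return {
        'max_bins': 1024,
        'smoothing_rounds': trial.suggest_categorical(
            'smoothing_rounds', [0, 25, 50, 75, 100, 150, 200, 350, 500]),
        'max_interaction_bins': 64,
        'interactions': trial.suggest_float('interactions', 0, 1.0),
        'greedy_ratio': trial.suggest_float('greedy_ratio', 0.0, 10.0),
        'cyclic_progress': trial.suggest_float('cyclic_progress', 0.0, 1.0),
        'outer_bags': 14,
        'learning_rate': trial.suggest_float('learning_rate', 0.0025, 0.2, log=True),
        'validation_size': trial.suggest_float('validation_size', 0.05, 0.4),
        'max_leaves': trial.suggest_int('max_leaves', 2, 3),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 2, 10),
        'min_hessian': trial.suggest_float('min_hessian', 1e-6, 0.01, log=True),
        'early_stopping_rounds': 100,
        'random_state': 42,
        'max_rounds': 25000,
    }
```

The returned dict could then be used inside the existing objective as `ExplainableBoostingRegressor(**suggest_params(trial), feature_names=feat, n_jobs=1)`.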

If you're still having fitting time issues, I might recommend a slightly more complicated staged process:

1. Fit a single EBM with defaults.
2. Choose some portion of the features to eliminate. There are many ways to do this, but a simple one would be through feature importance.
3. Use optuna on the reduced dataset.
4. Once optuna has determined hyperparameters, you could optionally retrain an EBM on the full dataset using the discovered hyperparameters.
5. If you want even better performance, you could try fitting one more EBM using inner_bags=20.

I haven't tried this procedure myself, but it seems like something that might work.
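
In rough code, steps 1 to 3 might look something like the following (an untested sketch; the top-10 cutoff is arbitrary, and it assumes the same X_transformed and y_transformed as in the post above):

```python
from interpret.glassbox import ExplainableBoostingRegressor

# 1) Fit a single EBM with default settings on the full feature set.
ebm = ExplainableBoostingRegressor(random_state=42)
ebm.fit(X_transformed, y_transformed)

# 2) Rank terms by importance. term_importances() covers both single features
#    and pairwise interactions, so keep only single-feature terms when deciding
#    which columns to retain.
importances = dict(zip(ebm.term_names_, ebm.term_importances()))
main_importances = {name: imp for name, imp in importances.items() if " & " not in name}
keep = sorted(main_importances, key=main_importances.get, reverse=True)[:10]

# 3) Run the Optuna search on the reduced dataset.
X_reduced = X_transformed[keep]
# study.optimize(...) as before, fitting on X_reduced instead of X_transformed
```

Step 4 would then just be refitting on the full X_transformed with study.best_params once the search finishes.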

twright8 commented 3 weeks ago

Thanks so much @paulbkoch. It was slow on what I thought were default settings, but it turns out I was accidentally still running inner_bags at 20, so I assume that's the issue! Thanks so much.