eminyous / fipe

https://arxiv.org/abs/2408.16167
MIT License

FIPE: Functionally Identical Pruning of Ensembles


This repository provides methods for Functionally-Identical Pruning of Tree Ensembles (FIPE). Given a trained scikit-learn tree ensemble, FIPE returns a pruned ensemble that is certified to make the same predictions as the original model on the entire feature space. The algorithm is described in detail in the paper: https://arxiv.org/abs/2408.16167.
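As we read the paper, the certificate comes from an iterative scheme: a pruning problem is solved on a finite set of sample points, an oracle then searches the whole feature space for a point where the pruned and original ensembles disagree, and any such counterexample is added to the sample set before re-solving. The sketch below illustrates that loop on a toy weighted vote of three threshold classifiers over the integer domain 0..9. It is not the fipe API: `counterexample_guided_prune`, `prune_on` and `find_counterexample` are hypothetical names, the brute-force searches stand in for the mixed-integer programs FIPE solves with Gurobi, and ties in the vote are assumed to go to class 0.

```python
from itertools import product
from typing import Callable, List, Optional, Sequence, Tuple

Weights = Tuple[int, ...]


def counterexample_guided_prune(
    prune_on: Callable[[List[int]], Weights],
    find_counterexample: Callable[[Weights], Optional[int]],
    initial_samples: Sequence[int],
    max_iter: int = 100,
) -> Tuple[Weights, List[int]]:
    """Alternate between pruning on samples and searching for a disagreement."""
    samples = list(initial_samples)
    for _ in range(max_iter):
        weights = prune_on(samples)  # minimum-l1 weights that agree on the samples
        x = find_counterexample(weights)  # a point where pruned != original, or None
        if x is None:
            return weights, samples  # no disagreement anywhere: identical predictions
        samples.append(x)  # add the counterexample and re-solve
    raise RuntimeError("no certificate within max_iter iterations")


# Toy ensemble (hypothetical): three threshold "trees" on the integer domain 0..9.
THRESHOLDS = (3, 5, 7)
ORIGINAL = (1, 1, 1)
DOMAIN = range(10)


def predict(weights: Weights, x: int) -> int:
    """Weighted vote between classes 0 and 1; ties go to class 0 (assumption)."""
    votes_one = sum(w for w, t in zip(weights, THRESHOLDS) if x >= t)
    votes_zero = sum(weights) - votes_one
    return int(votes_one > votes_zero)


def prune_on(samples: List[int]) -> Weights:
    """Brute-force stand-in for the pruning MIP: minimise the l1 norm of integer
    weights in {0, ..., 3} subject to agreeing with ORIGINAL on every sample."""
    feasible = (
        w
        for w in product(range(4), repeat=len(THRESHOLDS))
        if all(predict(w, x) == predict(ORIGINAL, x) for x in samples)
    )
    return min(feasible, key=sum)  # weights are non-negative, so sum == l1 norm


def find_counterexample(weights: Weights) -> Optional[int]:
    """Exhaustive stand-in for the oracle; only possible because DOMAIN is finite."""
    return next((x for x in DOMAIN if predict(weights, x) != predict(ORIGINAL, x)), None)


weights, samples = counterexample_guided_prune(prune_on, find_counterexample, [0])
print(weights, samples)  # (0, 1, 0) [0, 5]
```

The loop stops only when the oracle finds no disagreement, which is what makes the result functionally identical rather than merely accurate on the samples. Here the original vote predicts class 1 exactly when x >= 5, so keeping only the threshold-5 tree reproduces it everywhere.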

Installation

This project requires the Gurobi solver. Free academic licenses are available from Gurobi.

Run the following commands to create a virtual environment and install the package from PyPI. You may have to install Python and venv first.

    python3.10 -m venv env
    source env/bin/activate
    pip install fipepy

From the project root of a clone of this repository, the installation can be checked by running the test suite:

    pip install pytest
    pytest

The integration tests require a working Gurobi license. If a license is not available, the tests will pass and print a warning.
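Before running the integration tests it can help to know which kind of Gurobi license Python actually sees. The function below is a hedged sketch, not part of fipe; `gurobi_license_status` is a hypothetical helper. It relies on the fact that the gurobipy package from PyPI ships with a size-limited license (at most 2000 variables and 2000 linear constraints), so it tries to optimize a trivial model with 2001 variables: a full license solves it, while the size-limited license, or no license at all, raises `gurobipy.GurobiError`.

```python
def gurobi_license_status() -> str:
    """Report whether gurobipy is importable and whether its license is size-limited."""
    try:
        import gurobipy as gp
    except ImportError:
        return "gurobipy is not installed"
    try:
        with gp.Env(empty=True) as env:
            env.setParam("OutputFlag", 0)  # silence the solver log
            env.start()  # raises GurobiError if no license can be found
            with gp.Model(env=env) as model:
                model.addVars(2001)  # one more variable than the size-limited license allows
                model.optimize()  # raises GurobiError under the size-limited license
        return "full Gurobi license"
    except gp.GurobiError as err:
        return f"restricted or missing Gurobi license: {err}"


if __name__ == "__main__":
    print(gurobi_license_status())
```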

Getting started

A minimal working example to prune an AdaBoost ensemble is presented below.

    from fipe import FIPE, FeatureEncoder
    import pandas as pd
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import AdaBoostClassifier

    # Load data and encode features
    data = load_iris()
    X = pd.DataFrame(data.data)
    y = data.target

    encoder = FeatureEncoder(X)
    X = encoder.X.to_numpy()

    # Train tree ensemble
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    base = AdaBoostClassifier(algorithm="SAMME", n_estimators=100)
    base.fit(X_train, y_train)

    # Read the estimator weights and rescale them so the largest equals 1e5
    w = base.estimator_weights_
    w = (w / w.max()) * 1e5

    # Prune using FIPE
    norm = 1
    print(f'Pruning model by minimizing l_{norm} norm.')
    pruner = FIPE(base=base, weights=w, encoder=encoder, norm=norm, eps=1e-6)
    pruner.build()
    pruner.add_samples(X_train)
    pruner.oracle.setParam('LogToConsole', 0)
    pruner.prune()
    print('\nFinished pruning.')

    # Read pruned model
    n_activated = pruner.n_activated
    print(f'The pruned ensemble has {n_activated} estimators.')

    # Verify functional identity on the test data
    y_pred = base.predict(X_test)
    y_pruned = pruner.predict(X_test)
    fidelity = 100 * np.mean(y_pred == y_pruned)  # np.mean gives a fraction; scale to percent
    print(f'Fidelity to initial ensemble is {fidelity:.1f}%.')
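With test_size=0.2, the Iris test split holds only 30 points. Since FIPE's guarantee covers the whole feature space, a cheap extra spot check is to compare predictions on points drawn uniformly from the bounding box of the encoded data. The helpers below are a hypothetical sketch using only NumPy; `sample_bounding_box` and `fidelity_percent` are illustrative names, and they accept any two predict callables. Using them with `pruner.predict` assumes it accepts a 2-D NumPy array, as it does for `X_test` in the example. Uniform sampling suits the continuous Iris features; for binary or categorical encodings the sampled values would not be valid inputs.

```python
import numpy as np


def sample_bounding_box(X, n_samples: int, seed: int = 0) -> np.ndarray:
    """Draw points uniformly inside the axis-aligned bounding box of X."""
    X = np.asarray(X, dtype=float)
    low, high = X.min(axis=0), X.max(axis=0)
    rng = np.random.default_rng(seed)
    return rng.uniform(low, high, size=(n_samples, X.shape[1]))


def fidelity_percent(predict_a, predict_b, X) -> float:
    """Percentage of rows of X on which the two predict functions agree."""
    return 100.0 * float(np.mean(predict_a(X) == predict_b(X)))


# Usage with the objects from the example (assumes base, pruner and X are defined):
# X_box = sample_bounding_box(X, 10_000)
# print(f"Fidelity on random points: {fidelity_percent(base.predict, pruner.predict, X_box):.1f}%")
```

Agreement on random points is only a sanity check; the certificate itself comes from FIPE's optimization.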