Apache License 2.0

TabRepo

TabRepo contains the predictions and metrics of 1530 models evaluated on 211 classification and regression datasets. This makes it possible to compare against state-of-the-art AutoML systems or random configurations by querying precomputed results. We also store and expose model predictions, so any ensembling strategy can be benchmarked cheaply in the same way.

We provide the scripts from our paper, TabRepo: A Large Scale Repository of Tabular Model Evaluations and its AutoML Applications, so that one can reproduce all experiments that compare different models and portfolio strategies against state-of-the-art AutoML systems.

The key features of the repo are:

[Figures: tuning-impact.png, sensitivity.png, paper-figure.png]

Installation

To install the repository, ensure you are using Python 3.9-3.11. Other Python versions are not supported. Then, run the following:

git clone https://github.com/autogluon/tabrepo.git
pip install -e tabrepo

Only Linux has been tested. Support for Windows and macOS is not confirmed, and you may run into bugs or a suboptimal experience (especially if you are unable to install ray).
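As a quick sanity check before installing, the supported interpreter range can be verified from Python itself. This is a minimal sketch and not part of TabRepo: the helper name `is_supported_python` and the two constants are hypothetical, and the bounds are simply the 3.9-3.11 range stated above.

```python
import sys

# Supported range taken from the installation note (3.9 to 3.11 inclusive).
MIN_SUPPORTED = (3, 9)
MAX_SUPPORTED = (3, 11)


def is_supported_python(version_info=None):
    """Return True if the (major, minor) version lies within the supported range."""
    if version_info is None:
        version_info = sys.version_info
    major_minor = tuple(version_info[:2])
    return MIN_SUPPORTED <= major_minor <= MAX_SUPPORTED


if __name__ == "__main__":
    if not is_supported_python():
        current = f"{sys.version_info.major}.{sys.version_info.minor}"
        print(f"Python {current} is outside the supported 3.9-3.11 range")
```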

Reproducing AutoML Conf 2024 Paper

If you are interested in reproducing the experiments of the paper, you will need these extra dependencies:

# Install AG benchmark, required only to reproduce results showing win-rate tables

git clone https://github.com/autogluon/autogluon-bench.git
pip install -e autogluon-bench

git clone https://github.com/Innixma/autogluon-benchmark.git
pip install -e autogluon-benchmark

# Install extra dependencies used for results scripts
pip install autorank seaborn

You are all set!

Quick-start

Recommended: Refer to examples/run_quickstart.py for a full runnable tutorial.

Now let's see how to do basic things with TabRepo.

Accessing model evaluations. To access model evaluations, you can do the following:

from tabrepo import load_repository

repo = load_repository("D244_F3_C1530_30")
repo.metrics(datasets=["Australian"], configs=["CatBoost_r22_BAG_L1", "RandomForest_r12_BAG_L1"])

The code returns the metrics available for the chosen configurations and dataset.

The example loads a smaller version of TabRepo with only a few datasets for illustrative purposes. It shows the evaluations of one ensemble and how to query the stored predictions of a given model. When load_repository is called, model predictions and TabRepo metadata are fetched from the internet. We use a smaller version here because downloading all predictions can take a long time. To query all datasets, replace the context with D244_F3_C1530.

Querying model predictions. To query model predictions, run the following code:

from tabrepo import load_repository
repo = load_repository("D244_F3_C1530_30")
print(repo.predict_val_multi(dataset="Australian", fold=0, configs=["CatBoost_r22_BAG_L1", "RandomForest_r12_BAG_L1"]))

This will return predictions on the validation set. You can also use predict_test to get the predictions on the test set.

Simulating ensembles. To evaluate an ensemble of any list of configurations, you can run the following:

from tabrepo import load_repository
repo = load_repository("D244_F3_C1530_30")
print(repo.evaluate_ensemble(datasets=["Australian"], configs=["CatBoost_r22_BAG_L1", "RandomForest_r12_BAG_L1"]))

This code returns the error of an ensemble whose weights are computed with the Caruana procedure after loading the model predictions and the validation ground truth.
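To give an intuition for how such weights can be obtained, here is a simplified sketch of greedy ensemble selection with replacement in the style of Caruana et al. It is not TabRepo's implementation: the function name `caruana_weights` is hypothetical, predictions are assumed to be a 2D array of shape (n_models, n_rows), and RMSE stands in for whatever metric the task actually uses.

```python
import numpy as np


def caruana_weights(val_preds, y_val, n_iterations=10):
    """Greedy ensemble selection with replacement (simplified sketch).

    val_preds: array of shape (n_models, n_rows) with validation predictions.
    y_val: array of shape (n_rows,) with validation targets.
    Returns one weight per model; the weights sum to 1.
    """
    val_preds = np.asarray(val_preds, dtype=float)
    y_val = np.asarray(y_val, dtype=float)
    n_models = val_preds.shape[0]
    counts = np.zeros(n_models, dtype=int)
    running_sum = np.zeros_like(y_val)

    for step in range(1, n_iterations + 1):
        # Validation RMSE of the averaged ensemble if each candidate were added next.
        candidate_means = (running_sum[None, :] + val_preds) / step
        errors = np.sqrt(((candidate_means - y_val[None, :]) ** 2).mean(axis=1))
        best = int(np.argmin(errors))
        counts[best] += 1
        running_sum += val_preds[best]

    # A model selected k times out of n_iterations gets weight k / n_iterations.
    return counts / counts.sum()
```

Because models are selected with replacement, a strong model can be picked several times and end up with a larger weight, while models that never reduce the validation error keep a weight of zero.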

Available Contexts

Contexts are used to load a repository and are downloaded from S3 with the following code:

from tabrepo import load_repository
repo = load_repository(context_name)

Below is a list of the available contexts in TabRepo.

| Context Name | # Datasets | # Folds | # Configs | Disk Size | Notes |
|---|---|---|---|---|---|
| D244_F3_C1530 | 211 | 3 | 1530 | 330 GB | All successful datasets. 64 GB+ memory recommended. May take a few hours to download. |
| D244_F3_C1530_200 | 200 | 3 | 1530 | 120 GB | Used for results in paper. 32 GB memory recommended. |
| D244_F3_C1530_175 | 175 | 3 | 1530 | 57 GB | 16 GB memory recommended. |
| D244_F3_C1530_100 | 100 | 3 | 1530 | 9.5 GB | Ideal for fast prototyping. |
| D244_F3_C1530_30 | 30 | 3 | 1530 | 1.1 GB | Toy context. |
| D244_F3_C1530_10 | 10 | 3 | 1530 | 220 MB | Toy context. |
| D244_F3_C1530_3 | 3 | 3 | 1530 | 33 MB | Toy context. |
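When choosing a context programmatically, the disk sizes listed above can serve as a simple lookup. The sketch below is illustrative only: the dictionary `CONTEXT_DISK_GB` and the helper `largest_context_within` are hypothetical names and not part of TabRepo, and the sizes are copied from the list above (converted to GB).

```python
# Disk sizes in GB, copied from the list of available contexts.
CONTEXT_DISK_GB = {
    "D244_F3_C1530": 330,
    "D244_F3_C1530_200": 120,
    "D244_F3_C1530_175": 57,
    "D244_F3_C1530_100": 9.5,
    "D244_F3_C1530_30": 1.1,
    "D244_F3_C1530_10": 0.22,
    "D244_F3_C1530_3": 0.033,
}


def largest_context_within(budget_gb):
    """Return the largest context whose download fits in budget_gb, or None."""
    fitting = {name: size for name, size in CONTEXT_DISK_GB.items() if size <= budget_gb}
    if not fitting:
        return None
    return max(fitting, key=fitting.get)
```

The budget could, for instance, be derived from the free space reported by `shutil.disk_usage(".").free / 1e9`, and the returned name passed to `load_repository`. Memory requirements are not taken into account here.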

Reproducing paper experiments

To ensure reproducibility, you can use the AutoML2024 branch which provides a snapshot of the code that is able to reproduce the results.

To reproduce the experiments from the paper, run:

python scripts/baseline_comparison/evaluate_baselines.py

The experiment requires ~200 GB of disk storage and 32 GB of memory (although we use memmap to load model predictions on the fly, large datasets still have a significant memory footprint even for a couple of models). In particular, we used an m6i.4xlarge machine for our experiments, which took under 24 hours (less than $7 of compute with spot instance pricing). Excluding the 10-repeat seeded ablations, the experiments take under 1 hour.

All the tables and figures of the paper will be generated under scripts/output/{expname}.
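For readers unfamiliar with memory mapping, the following sketch shows the general idea with `numpy.memmap`: predictions live in a file on disk and only the requested slice is copied into memory. It does not reflect TabRepo's actual storage layout; the file name, array shape, and helper names are assumptions made for illustration.

```python
import os
import tempfile

import numpy as np


def write_predictions(path, preds):
    """Store a float32 prediction array on disk as a raw memory-mapped file."""
    out = np.memmap(path, dtype=np.float32, mode="w+", shape=preds.shape)
    out[:] = preds
    out.flush()
    del out  # release the file handle


def read_model_predictions(path, shape, model_index):
    """Read the predictions of a single model without loading the whole file."""
    mapped = np.memmap(path, dtype=np.float32, mode="r", shape=shape)
    # Copy only the requested slice into memory.
    return np.array(mapped[model_index])


# Toy data: 4 hypothetical models with 1000 validation rows each.
rng = np.random.default_rng(0)
preds = rng.random((4, 1000), dtype=np.float32)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "preds.bin")
    write_predictions(path, preds)
    one_model = read_model_predictions(path, preds.shape, model_index=2)
```

Even with this approach, each slice that is read still has to fit in memory, which is why large datasets keep a noticeable footprint.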

Colab Notebook

To run a subset of experiments on a Colab notebook, refer to https://colab.research.google.com/github/autogluon/tabrepo/blob/main/examples/TabRepo_Reproducibility.ipynb

Reproducing the raw TabRepo dataset

To reproduce the entire TabRepo dataset (context "D244_F3_C1530") from scratch, refer to the benchmark execution README.

To instead reproduce a small subset of the TabRepo dataset in a few minutes, run examples/run_quickstart_from_scratch.py.

Future work

We have been using TabRepo to study potential data leakage in AutoGluon stacking. We believe the following are also potentially interesting directions for future work:

Citation

If you find this work useful for your research, please cite the following:

@inproceedings{
  tabrepo,
  title={TabRepo: A Large Scale Repository of Tabular Model Evaluations and its Auto{ML} Applications},
  author={David Salinas and Nick Erickson},
  booktitle={AutoML Conference 2024 (ABCD Track)},
  year={2024},
  url={https://openreview.net/forum?id=03V2bjfsFC}
}