The official implementation of the paper,
SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning
:large_orange_diamond: Note: The extended version of SubTab with codes and pre-processed data for Adult Income and BlogFeedback datasets can be found at: https://github.com/talipucar/SubTab_extended
NeurIPS 2021 slides | NeurIPS 2021 poster |
---|---|
We used Python 3.7 for our experiments. The environment can be set up by following three steps:
pip install pipenv # To install pipenv if you don't have it already
pipenv install --skip-lock # To install required packages.
pipenv shell # To activate virtual env
If the second step fails, you can install the packages listed in the Pipfile individually with pip, i.e. "pip install package_name".
The MNIST dataset is already provided to demo the framework. For your own dataset, follow the instructions in Adding New Datasets.
There are two types of configuration files:
1. runtime.yaml
2. mnist.yaml
runtime.yaml is a high-level configuration file shared by all datasets; among other things, it holds the data root directory and the MLFlow switch.
The second configuration file is dataset-specific and is used to configure the architecture of the model, loss functions, and so on.
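As a rough illustration of how the two configuration levels can fit together, the sketch below merges a shared runtime dictionary with a dataset-specific one. The key names (`data_root`, `use_mlflow`, `hidden_dims`, `reconstruction_loss`) and the `merge_configs` helper are hypothetical stand-ins, not the actual keys or code of this repository; check runtime.yaml and mnist.yaml for the real options.

```python
def merge_configs(runtime_cfg, dataset_cfg):
    """Combine shared and dataset-specific options; dataset keys win on conflict."""
    merged = dict(runtime_cfg)   # copy so the shared config is not mutated
    merged.update(dataset_cfg)   # dataset-specific values override shared ones
    return merged


# Hypothetical contents standing in for the parsed YAML files.
runtime_cfg = {"data_root": "./data/", "use_mlflow": False}
dataset_cfg = {"hidden_dims": [512, 256], "reconstruction_loss": True}

config = merge_configs(runtime_cfg, dataset_cfg)
```

Keeping the shared options in one file means that a new dataset only needs its own small YAML file for model and loss settings.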
You can train and evaluate the model by using:
python train.py # For training.
python eval.py # For evaluation
train.py will also run evaluation at the end of training, and eval.py can be run on its own to evaluate afterwards.
Command-line arguments are defined in utils/arguments.py. Use the -h argument to get help when running scripts, and -d dataset_name to run scripts on new datasets.
For each new dataset, you can use the following steps:
1. Provide a _load_dataset_name() function, similar to the MNIST load function: for example, _load_tcga() for the tcga dataset, or _load_income() for the income dataset. Add a separate elif condition for it within the _load_data() method of the TabularDataset() class in utils/load_data.py.
2. Create a new config file with the same name as the dataset, for example tcga.yaml for the tcga dataset, or income.yaml for the income dataset. You can also duplicate one of the existing configuration files (e.g. mnist.yaml) and rename it. Make sure that the new config file is under the config/ directory.
3. Provide a data folder with the pre-processed training and test sets, and place it under the ./data/ directory. You can also do the train-test split and pre-processing within your custom _load_dataset_name() function.
4. (Optional) If you want to place the new dataset under a directory other than the local "./data/", place the dataset folder anywhere and set the root directory to it in /config/runtime.yaml. For example, if the path to the tcga dataset is /home/.../data/tcga/, you only need to include /home/.../data/ in runtime.yaml. The code will fill in the tcga folder name from the name given in the command-line argument (e.g. -d dataset_name; in this case, dataset_name would be tcga).
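The steps above can be sketched roughly as follows. This is a minimal illustration, not the repository's code: `TabularDatasetSketch`, `parse_args`, the `/tmp/data/` root, the placeholder data, and the four-tuple return value are all assumptions made for the example. The real class is `TabularDataset` in utils/load_data.py, and its loaders may return data in a different form.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse a -d flag, mirroring the `-d dataset_name` usage described above."""
    parser = argparse.ArgumentParser(description="Hypothetical launcher sketch")
    parser.add_argument("-d", "--dataset", type=str, default="mnist",
                        help="Name of the dataset folder and of its config file")
    return parser.parse_args(argv)


class TabularDatasetSketch:
    """Illustrative stand-in for the dispatch inside a _load_data() method."""

    def __init__(self, data_root, dataset_name):
        self.dataset_name = dataset_name
        # The root comes from the runtime config; the folder name comes from -d.
        self.data_path = os.path.join(data_root, dataset_name)

    def _load_data(self):
        # One branch per dataset; a new dataset needs its own elif here.
        if self.dataset_name == "mnist":
            return self._load_mnist()
        elif self.dataset_name == "tcga":
            return self._load_tcga()
        raise ValueError(f"No loader registered for '{self.dataset_name}'")

    def _load_mnist(self):
        raise NotImplementedError("Placeholder for the existing MNIST loader")

    def _load_tcga(self):
        # Placeholder values; a real loader would read files under self.data_path
        # and could also do the train-test split and pre-processing here.
        x_train, y_train = [[0.1, 0.2], [0.3, 0.4]], [0, 1]
        x_test, y_test = [[0.5, 0.6]], [1]
        return x_train, y_train, x_test, y_test


args = parse_args(["-d", "tcga"])  # equivalent to passing `-d tcga` on the command line
dataset = TabularDatasetSketch("/tmp/data/", args.dataset)
x_train, y_train, x_test, y_test = dataset._load_data()
```

Because the dataset name drives both the folder lookup and the loader dispatch, the same name should be used for the data folder, the config file, and the `-d` argument.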
- train.py
- eval.py
- src
    |-model.py
- config
    |-runtime.yaml
    |-mnist.yaml
- utils
    |-load_data.py
    |-arguments.py
    |-model_utils.py
    |-loss_functions.py
    ...
- data
    |-mnist
    ...
- results
    |
    ...
Results at the end of training are saved under the ./results directory. The results directory is structured as follows:
- results
    |-dataset name
        |-evaluation
            |-clusters (for plotting t-SNE and PCA plots of embeddings)
            |-reconstructions (not used)
        |-training
            |-model_mode (e.g. ae for autoencoder)
                |-model
                |-plots
                |-loss
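For orientation, here is a small sketch that builds a directory tree of this shape with `pathlib`. It assumes the nesting results/dataset name/{evaluation, training}, with `clusters` and `reconstructions` under `evaluation`, and `model`, `plots`, and `loss` under `training/model_mode`. The `make_results_dirs` helper is hypothetical and is not part of the repository.

```python
import tempfile
from pathlib import Path


def make_results_dirs(results_root, dataset_name, model_mode="ae"):
    """Create the results tree described above and return its leaf paths by name."""
    base = Path(results_root) / dataset_name
    training = base / "training" / model_mode
    paths = {
        "clusters": base / "evaluation" / "clusters",
        "reconstructions": base / "evaluation" / "reconstructions",
        "model": training / "model",
        "plots": training / "plots",
        "loss": training / "loss",
    }
    for path in paths.values():
        # parents=True creates intermediate folders; exist_ok=True makes reruns safe.
        path.mkdir(parents=True, exist_ok=True)
    return paths


# Demonstrate inside a temporary directory so nothing is left behind.
with tempfile.TemporaryDirectory() as tmp:
    paths = make_results_dirs(Path(tmp) / "results", "mnist")
    all_exist = all(p.is_dir() for p in paths.values())
    model_dir_parts = paths["model"].relative_to(tmp).parts
```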
You can save results of evaluations under "evaluation" folder.
MLFlow is used to track experiments. It is turned off by default, but can be turned on by changing the corresponding option in the runtime config file, ./config/runtime.yaml.
@article{ucar2021subtab,
title={SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning},
author={Ucar, Talip and Hajiramezanali, Ehsan and Edwards, Lindsay},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
If you use the SubTab framework in your own studies and work, please cite it using the following:
@Misc{talip_ucar_2021_SubTab,
author = {Talip Ucar},
title = {{SubTab: Subsetting Features of Tabular Data for Self-Supervised Representation Learning}},
howpublished = {\url{https://github.com/AstraZeneca/SubTab}},
month = June,
year = {since 2021}
}