This repository contains an implementation of SAINT (Self-Attention and Intersample Attention Transformer) that uses Pytorch-Lightning_ as the framework and Hydra_ for configuration. The paper is available on arxiv_.
Check the website_ for more information.

.. _arxiv: https://arxiv.org/abs/2106.01342
.. _Pytorch-Lightning: https://www.pytorchlightning.ai/
.. _Hydra: https://hydra.cc/
.. _website: https://actis92.github.io/lit-saint/

.. code-block:: bash

    pip install lit-saint

1. Create a YAML file that contains the configuration needed by the application, or use the default values.
2. Create an instance of ``SaintConfig`` using Hydra.
3. Create the DataFrame that will be used by the model. To split the data correctly, add a new column that assigns the label "train" to the rows of the training set, "validation" to the rows of the validation set, and "test" to the rows of the test set.

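
The split column described above can be built with plain pandas. The following is a minimal sketch, not part of lit-saint: the toy data, the feature column names, and the 60/20/20 proportions are assumptions chosen for illustration, while ``TARGET`` and ``SPLIT`` match the names used in the examples in this document.

.. code-block:: python3

    import numpy as np
    import pandas as pd

    # Hypothetical toy data; replace it with your own DataFrame.
    df = pd.DataFrame({
        "feature_num": np.arange(10, dtype=float),
        "feature_cat": ["a", "b"] * 5,
        "TARGET": [0, 1] * 5,
    })

    # Shuffle the row positions, then assign a split label by position.
    rng = np.random.default_rng(seed=0)
    shuffled = rng.permutation(len(df))
    n_train = int(0.6 * len(df))  # assumed proportion
    n_val = int(0.2 * len(df))    # assumed proportion

    split = np.empty(len(df), dtype=object)
    split[shuffled[:n_train]] = "train"
    split[shuffled[n_train:n_train + n_val]] = "validation"
    split[shuffled[n_train + n_val:]] = "test"
    df["SPLIT"] = split

Assigning labels by shuffled position, rather than sampling a label per row, guarantees that every split is non-empty and has the intended size.
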
.. code-block:: python3

    # "TARGET" is the label column; "SPLIT" holds the train/validation/test labels
    data_module = SaintDatamodule(df=df, target="TARGET", split_column="SPLIT")

.. code-block:: python3

    model = Saint(categories=data_module.categorical_dims, continuous=data_module.numerical_columns,
                  config=cfg, dim_target=data_module.dim_target)

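.. code-block:: python3

    from pytorch_lightning import Trainer

    pretrainer = Trainer(max_epochs=1)  # used for the pretraining stage
    trainer = Trainer(max_epochs=5)     # used for the supervised training stage
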
.. code-block:: python3

    saint_trainer = SaintTrainer(pretrainer=pretrainer, trainer=trainer)
    saint_trainer.fit(model=model, datamodule=data_module, enable_pretraining=True)

.. code-block:: python3

    import numpy as np

    prediction = saint_trainer.predict(model=model, datamodule=data_module, df=df_to_predict)
    # Keep the index of the highest-scoring class for each row
    df_to_predict["prediction"] = np.argmax(prediction, axis=1)

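
The ``np.argmax`` call above converts per-class scores into class indices. The following is a minimal sketch of that step in isolation, assuming ``prediction`` is a 2-D array with one row per sample and one column per class; the scores are made up for illustration and do not come from a trained model.

.. code-block:: python3

    import numpy as np

    # Hypothetical scores for 3 samples and 2 classes.
    prediction = np.array([
        [0.9, 0.1],
        [0.2, 0.8],
        [0.4, 0.6],
    ])

    # axis=1 selects, for each row, the column index with the largest score.
    predicted_class = np.argmax(prediction, axis=1)
    print(predicted_class)
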
Preprocessing
^^^^^^^^^^^^^^

Some suggestions are:
.. code-block:: python3

    from lit_saint import SaintConfig
    from omegaconf import OmegaConf

    conf = OmegaConf.create(SaintConfig)
    with open("<FILE_NAME>", "w+") as fp:
        OmegaConf.save(config=conf, f=fp.name)

To enable type validation at runtime, add the following lines at the beginning of your YAML file:

.. code-block:: yaml

    defaults:
      - base_config

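
To illustrate what runtime type validation means here, the sketch below merges user-supplied values onto a typed schema with plain OmegaConf, which is roughly what Hydra does when a YAML file lists a structured config in its ``defaults``. ``ToySchema`` and its fields are hypothetical stand-ins for ``SaintConfig``, not part of lit-saint, and the Hydra registration step is not shown.

.. code-block:: python3

    from dataclasses import dataclass

    from omegaconf import OmegaConf
    from omegaconf.errors import ValidationError


    @dataclass
    class ToySchema:
        # Hypothetical fields, used only for this illustration.
        max_epochs: int = 5
        learning_rate: float = 0.001


    schema = OmegaConf.structured(ToySchema)

    # A value of the declared type merges cleanly.
    ok = OmegaConf.merge(schema, OmegaConf.create({"max_epochs": 10}))

    # A value that cannot be converted to the declared type is rejected.
    try:
        OmegaConf.merge(schema, OmegaConf.create({"max_epochs": "many"}))
        rejected = False
    except ValidationError:
        rejected = True

Without a schema, the second merge would silently store the string, and the mistake would surface only later during training.
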
We would like to thank the authors of the official SAINT implementation: https://github.com/somepago/saint