romilbert / samformer

Official implementation of SAMformer, a transformer leveraging Sharpness-Aware Minimization and Channel-Wise Attention for Time Series Forecasting.
MIT License

SAMformer (ICML'24 Oral)

This repository contains the official implementation of SAMformer, a transformer-based model for time series forecasting, from the paper:

SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention. Romain Ilbert, Ambroise Odonnat, Vasilii Feofanov, Aladin Virmaux, Giuseppe Paolo, Themis Palpanas, Ievgen Redko.
*Equal contribution.

An ICML oral presentation on SAMformer is also available.

Overview

SAMformer is a lightweight transformer architecture designed for time series forecasting. It combines Sharpness-Aware Minimization (SAM) with a channel-wise attention mechanism and achieves state-of-the-art performance in multivariate long-term forecasting across a variety of tasks. In particular, SAMformer surpasses TSMixer by $\mathbf{14.33}$% on average while having $\mathbf{\sim4}$ times fewer parameters, and surpasses iTransformer and PatchTST by $\mathbf{6.58}$% and $\mathbf{8.79}$%, respectively.

Architecture

SAMformer takes as input a $D$-dimensional time series of length $L$ (look-back window), arranged in a matrix $\mathbf{X}\in\mathbb{R}^{D\times L}$ and predicts its next $H$ values (prediction horizon), denoted by $\mathbf{Y}\in\mathbb{R}^{D\times H}$. The main components of the architecture are the following.
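To make the shapes concrete, here is a minimal NumPy sketch of how a long multivariate series could be cut into look-back/horizon pairs $(\mathbf{X}, \mathbf{Y})$. The helper name `make_windows` and the toy series are assumptions made for illustration; they are not taken from the repository's data pipeline.

```python
import numpy as np

def make_windows(series, L, H):
    """Split a (D, T) series into look-back/horizon pairs.

    Hypothetical helper: returns X with shape (N, D, L) and Y with
    shape (N, D, H), where N = T - L - H + 1 is the number of windows.
    """
    D, T = series.shape
    n = T - L - H + 1
    if n <= 0:
        raise ValueError("series too short for the requested L and H")
    # Window i looks back over steps [i, i + L) and predicts [i + L, i + L + H).
    X = np.stack([series[:, i:i + L] for i in range(n)])
    Y = np.stack([series[:, i + L:i + L + H] for i in range(n)])
    return X, Y

# Toy series with D = 2 channels and T = 10 time steps.
series = np.arange(2 * 10, dtype=float).reshape(2, 10)
X, Y = make_windows(series, L=4, H=2)
print(X.shape, Y.shape)  # (5, 2, 4) (5, 2, 2)
```

Each `X[i]` is one $D \times L$ input matrix and each `Y[i]` is the matching $D \times H$ target.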

💡 Shallow transformer encoder. The neural network at the core of SAMformer is a shallow encoder of a simplified Transformer. Channel-wise attention is applied to the input, followed by a residual connection. Instead of the usual feedforward block, a linear layer is directly applied on top of the residual connection to output the prediction.

💡 Channel-Wise Attention. Contrary to the usual temporal attention in $\mathbb{R}^{L \times L}$, the channel-wise self-attention is represented by a matrix in $\mathbb{R}^{D \times D}$ and consists of the pairwise correlations between the input's features. This brings two important benefits:

💡 Reversible Instance Normalization (RevIN). The resulting network is equipped with RevIN, a two-step normalization scheme to handle the shift between the training and testing time series.

💡 Sharpness-Aware Minimization (SAM). As suggested by our empirical and theoretical analysis, we optimize the model with SAM to make it converge towards flatter minima, hence improving its generalization capacity.
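As an illustration of the SAM idea, the sketch below applies the standard two-step SAM update (perturb the weights along the normalized gradient, then descend using the gradient taken at the perturbed point) to a toy least-squares problem in NumPy. The problem, the function names, and the values of `lr` and `rho` are assumptions chosen for the example; this is not the optimizer code used in the repository.

```python
import numpy as np

def loss_and_grad(w, A, b):
    """Mean squared error 0.5 * mean((A w - b)^2) and its gradient."""
    r = A @ w - b
    return 0.5 * np.mean(r ** 2), A.T @ r / len(b)

def sam_step(w, A, b, lr=0.1, rho=0.05):
    """One SAM update; lr and rho are illustrative values."""
    _, g = loss_and_grad(w, A, b)
    # Step 1: move to the (approximate) worst-case point in a ball of radius rho.
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Step 2: descend from w using the gradient evaluated at w + eps.
    _, g_adv = loss_and_grad(w + eps, A, b)
    return w - lr * g_adv

rng = np.random.default_rng(0)
A = rng.normal(size=(64, 3))
b = A @ np.array([1.0, -2.0, 0.5])

w = np.zeros(3)
initial_loss, _ = loss_and_grad(w, A, b)
for _ in range(200):
    w = sam_step(w, A, b)
final_loss, _ = loss_and_grad(w, A, b)
print(initial_loss, final_loss)
```

Each SAM step costs two gradient evaluations instead of one, which is the price paid for steering the optimizer towards flatter regions of the loss.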

SAMformer uniquely combines all these components in a lightweight implementation with very few hyperparameters. We display below the resulting architecture.
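The forward pass described by these components can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not the repository's PyTorch model: the weight names and shapes (`W_q`, `W_k`, `W_v`, `W_o`, `W_out`, `d_model`) are hypothetical, the weights are random rather than trained, and RevIN is reduced to per-channel standardization without its learnable affine parameters.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def forward(X, W_q, W_k, W_v, W_o, W_out, eps=1e-5):
    """Map a (D, L) look-back window to a (D, H) forecast."""
    # RevIN, step 1: normalize each channel over the look-back window.
    mean = X.mean(axis=1, keepdims=True)
    std = X.std(axis=1, keepdims=True) + eps
    Xn = (X - mean) / std
    # Channel-wise attention: the score matrix is (D, D), not (L, L).
    Q, K, V = Xn @ W_q, Xn @ W_k, Xn @ W_v           # each (D, d_model)
    attn = softmax(Q @ K.T / np.sqrt(Q.shape[1]))    # (D, D)
    # Residual connection, then a single linear layer instead of a feedforward block.
    Z = Xn + (attn @ V) @ W_o                        # (D, L)
    Yn = Z @ W_out                                   # (D, H)
    # RevIN, step 2: restore the original scale and offset.
    return Yn * std + mean, attn

D, L, H, d_model = 3, 16, 4, 8
rng = np.random.default_rng(0)
X = rng.normal(size=(D, L))
W_q, W_k, W_v = (rng.normal(size=(L, d_model)) for _ in range(3))
W_o = rng.normal(size=(d_model, L))
W_out = rng.normal(size=(L, H))
Y, A = forward(X, W_q, W_k, W_v, W_o, W_out)
print(Y.shape, A.shape)  # (3, 4) (3, 3)
```

Because the attention matrix is $D \times D$, its size depends on the number of channels rather than on the length of the look-back window.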

Results

We conduct our experiments on various multivariate time series forecasting benchmarks.

🥇 Improved performance. SAMformer outperforms its competitors in $\mathbf{7}$ out of $\mathbf{8}$ datasets by a large margin. In particular, it improves over its best competitor TSMixer+SAM by $\mathbf{5.25}$%, surpasses the standalone TSMixer by $\mathbf{14.33}$%, and the best transformer-based model FEDformer by $\mathbf{12.36}$%. In addition, it improves over the vanilla Transformer by $\mathbf{16.96}$%. For each dataset and horizon, SAMformer is ranked either first or second.

🚀 Computational efficiency and versatility. SAMformer has a lightweight implementation with few learnable parameters, contrary to most of its competitors, leading to improved computational efficiency. SAMformer significantly outperforms the state of the art in multivariate time series forecasting despite having fewer parameters. In addition, the same architecture is used for all the datasets, while most of the other baselines require heavy hyperparameter tuning, which showcases the versatility of our approach.

📚 Qualitative benefits. We display in our paper the benefits of SAMformer in terms of smoothness of the loss landscape, robustness to the prediction horizons, and signal propagation in the attention layer.

Installation

To get started with SAMformer, clone this repository and install the required packages.

git clone https://github.com/romilbert/samformer.git
cd samformer
pip install -r requirements.txt

Make sure you have Python 3.8 or a newer version installed.
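If you are unsure which interpreter your environment uses, a quick check along the following lines can help. The helper name `meets_minimum` is hypothetical and not part of the repository.

```python
import sys

def meets_minimum(version_info=sys.version_info, minimum=(3, 8)):
    """Return True if the (major, minor) version is at least `minimum`."""
    return tuple(version_info[:2]) >= minimum

if meets_minimum():
    print("Python version OK:", sys.version.split()[0])
else:
    print("SAMformer requires Python 3.8 or newer")
```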

Modules

SAMformer consists of several key modules:

Usage

To launch the training and evaluation process, use the run_script.sh script with the appropriate arguments:

sh run_script.sh -m [model_name] -d [dataset_name] -s [sequence_length] -u -a

Script Arguments

Example

sh run_script.sh -m transformer -d ETTh1 -u -a
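When launching several runs, it can be convenient to assemble the command programmatically. The sketch below only mirrors the command-line pattern shown above: `build_command` is a hypothetical helper, and the `-u` and `-a` flags are passed through exactly as in the example, without assuming what they do.

```python
import shlex

def build_command(model, dataset, seq_len=None, extra_flags=("-u", "-a")):
    """Build the argument list for run_script.sh (hypothetical helper)."""
    cmd = ["sh", "run_script.sh", "-m", model, "-d", dataset]
    if seq_len is not None:
        # -s is optional here, as in the example above that omits it.
        cmd += ["-s", str(seq_len)]
    cmd += list(extra_flags)
    return cmd

cmd = build_command("transformer", "ETTh1")
print(shlex.join(cmd))  # sh run_script.sh -m transformer -d ETTh1 -u -a
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` from the repository root to start the run.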

Open-source Participation

Do not hesitate to contribute to this project. We would be happy to receive feedback and integrate your suggestions.

License

The code is distributed under the MIT license.

Authors

Romain Ilbert designed the methodology, developed the codebase and led the experiments. Ambroise Odonnat designed the methodology, developed the theory and led the writing. Vasilii Feofanov provided the PyTorch implementation of SAMformer. All authors contributed to discussions and writing. Correspondence to romain.ilbert@hotmail.fr and ambroiseodonnattechnologie@gmail.com.

Acknowledgements

We would like to express our gratitude to all the researchers and developers whose work and open-source software have contributed to the development of SAMformer. Special thanks to the authors of SAM, TSMixer, RevIN and $\sigma$Reparam for their instructive works, which have enabled our approach. We provide below a non-exhaustive list of GitHub repositories that provided valuable code bases and datasets:

Citation

If you find this work useful in your research, please cite:

@InProceedings{ilbert2024samformer,
  title =    {SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention},
  author =       {Ilbert, Romain and Odonnat, Ambroise and Feofanov, Vasilii and Virmaux, Aladin and Paolo, Giuseppe and Palpanas, Themis and Redko, Ievgen},
  booktitle =    {Proceedings of the 41st International Conference on Machine Learning},
  year =     {2024},
  volume =   {235},
  publisher =    {PMLR},
  url =      {https://proceedings.mlr.press/v235/ilbert24a.html},
}