
Official implementation of SAMformer, a transformer leveraging Sharpness-Aware Minimization and Channel-Wise Attention for Time Series Forecasting.
MIT License

SAMformer Paper (ICML'24 Oral)

The repository contains the official implementation of SAMformer, a transformer-based model for time series forecasting described in:

Romain Ilbert, Ambroise Odonnat, Vasilii Feofanov, Aladin Virmaux, Giuseppe Paolo, Themis Palpanas, Ievgen Redko. SAMformer: Unlocking the Potential of Transformers in Time Series Forecasting with Sharpness-Aware Minimization and Channel-Wise Attention.
*Equal contribution for writing the paper.

SAMformer Code and Experiments

The design and implementation of SAMformer, as well as this repository, were developed by Romain Ilbert, who also conducted the experiments presented in the paper. Paper slides can be found here.

Overview

SAMformer is a novel lightweight transformer architecture designed for time series forecasting. It uniquely integrates Sharpness-Aware Minimization (SAM) with a Channel-Wise Attention mechanism. This method achieves state-of-the-art performance in multivariate long-term forecasting across various forecasting tasks. In particular, SAMformer surpasses TSMixer by $\mathbf{14.33}$% on average while having $\mathbf{\sim4}$ times fewer parameters, and it surpasses iTransformer and PatchTST by $\mathbf{6.58}$% and $\mathbf{8.79}$%, respectively.

Architecture

SAMformer takes as input a $D$-dimensional time series of length $L$ (look-back window), arranged in a matrix $\mathbf{X}\in\mathbb{R}^{D\times L}$ and predicts its next $H$ values (prediction horizon), denoted by $\mathbf{Y}\in\mathbb{R}^{D\times H}$. The main components of the architecture are the following.

💡 Shallow transformer encoder. The neural network at the core of SAMformer is a shallow encoder of a simplified Transformer. Channel-wise attention is applied to the input, followed by a residual connection. Instead of the usual feedforward block, a linear layer is directly applied on top of the residual connection to output the prediction.
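The forward pass described above can be sketched in NumPy. This is a minimal, single-head illustration under our own assumptions, not the repository's implementation: the parameter names (`W_Q`, `W_K`, `W_V`, `W_O`, `W`, `b`), the scaled dot-product form of the attention, the bias in the linear head, and the sizes are all hypothetical. It shows the key structural point: the tokens are the $D$ channels, each described by its full length-$L$ history, so the attention matrix is $D \times D$.

```python
import numpy as np

def softmax(z, axis=-1):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def shallow_encoder(X, p):
    """Hypothetical single-head SAMformer-style encoder.

    X: (D, L) input, one row per channel.
    Returns the prediction (D, H) and the channel-wise attention matrix (D, D).
    """
    Q = X @ p["W_Q"]                             # (D, d_m): each channel is a token
    K = X @ p["W_K"]                             # (D, d_m)
    V = X @ p["W_V"]                             # (D, d_m)
    A = softmax(Q @ K.T / np.sqrt(Q.shape[1]))   # (D, D) channel-wise attention
    Z = X + A @ V @ p["W_O"]                     # residual connection, back to (D, L)
    Y_hat = Z @ p["W"] + p["b"]                  # linear head instead of a feedforward block
    return Y_hat, A

# Illustrative sizes: D channels, look-back L, horizon H, attention width d_m.
D, L, H, d_m = 7, 512, 96, 16
rng = np.random.default_rng(0)
p = {
    "W_Q": rng.normal(scale=L ** -0.5, size=(L, d_m)),
    "W_K": rng.normal(scale=L ** -0.5, size=(L, d_m)),
    "W_V": rng.normal(scale=L ** -0.5, size=(L, d_m)),
    "W_O": rng.normal(scale=d_m ** -0.5, size=(d_m, L)),
    "W": rng.normal(scale=L ** -0.5, size=(L, H)),
    "b": np.zeros(H),
}
X = rng.normal(size=(D, L))
Y_hat, A = shallow_encoder(X, p)
print(Y_hat.shape, A.shape)  # (7, 96) (7, 7)
```

Because the channels play the role of tokens, the size of the attention matrix depends only on $D$: with the illustrative sizes above it is $7 \times 7$ rather than $512 \times 512$.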

💡 Channel-Wise Attention. Contrary to the usual temporal attention in $\mathbb{R}^{L \times L}$, the channel-wise self-attention is represented by a matrix in $\mathbb{R}^{D \times D}$ and consists of the pairwise correlations between the input's features. This brings two important benefits:

💡 Reversible Instance Normalization (RevIN). The resulting network is equipped with RevIN, a two-step normalization scheme to handle the shift between the training and testing time series.
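The two RevIN steps can be sketched as follows, assuming per-instance, per-channel statistics computed over the look-back window and a learnable affine pair (`gamma`, `beta`) per channel, as in the original RevIN formulation. The class and method names are hypothetical and not taken from this repository. The input is normalized before entering the network, and the inverse transform is applied to the network's prediction.

```python
import numpy as np

class RevIN:
    """Minimal reversible instance normalization for one (D, L) instance."""

    def __init__(self, num_channels, eps=1e-5):
        self.eps = eps
        self.gamma = np.ones((num_channels, 1))   # learnable per-channel scale
        self.beta = np.zeros((num_channels, 1))   # learnable per-channel shift

    def normalize(self, X):
        # Step 1: store this instance's per-channel statistics over time, then standardize.
        self.mean = X.mean(axis=1, keepdims=True)
        self.std = np.sqrt(X.var(axis=1, keepdims=True) + self.eps)
        return (X - self.mean) / self.std * self.gamma + self.beta

    def denormalize(self, Y):
        # Step 2: undo the affine map and restore the stored statistics on the (D, H) output.
        return (Y - self.beta) / self.gamma * self.std + self.mean

rng = np.random.default_rng(0)
X = 5.0 + 3.0 * rng.normal(size=(3, 50))  # (D, L) instance with shifted mean and scale
rev = RevIN(num_channels=3)
X_norm = rev.normalize(X)
Y_norm = X_norm[:, -10:]   # stand-in for a model's (D, H) output in normalized space
Y = rev.denormalize(Y_norm)
print(np.allclose(Y, X[:, -10:]))  # True: denormalize inverts normalize
```

During training, `gamma` and `beta` would be updated by the optimizer together with the network weights; in this sketch they stay at their identity initialization.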

💡 Sharpness-Aware Minimization (SAM). As suggested by our empirical and theoretical analysis, we optimize the model with SAM to make it converge towards flatter minima, hence improving its generalization capacity.
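SAM replaces the plain gradient step with a two-pass update: it first moves the weights to an approximate worst-case point within an $\ell_2$ ball of radius $\rho$, namely $w + \epsilon$ with $\epsilon = \rho\, g / \lVert g \rVert$ where $g = \nabla \mathcal{L}(w)$, then updates $w$ using the gradient computed at $w + \epsilon$. Below is a minimal NumPy sketch on a toy least-squares problem. The learning rate, `rho`, and data are illustrative assumptions, and plain gradient descent stands in for whichever base optimizer SAMformer actually wraps with SAM.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One Sharpness-Aware Minimization update with plain gradient descent as the base step."""
    g = grad_fn(w)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # first-order worst-case perturbation
    g_sam = grad_fn(w + eps)                     # gradient at the perturbed weights
    return w - lr * g_sam                        # update the original (unperturbed) weights

# Toy noise-free least-squares problem: loss(w) = 0.5 * mean((A w - y)^2).
rng = np.random.default_rng(0)
A = rng.normal(size=(100, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = A @ w_true

def loss(w):
    return 0.5 * np.mean((A @ w - y) ** 2)

def grad(w):
    return A.T @ (A @ w - y) / len(y)

w = np.zeros(3)
for _ in range(300):
    w = sam_step(w, grad)
print(loss(np.zeros(3)), loss(w))  # initial vs. final loss
```

Each SAM step costs two gradient evaluations instead of one; that extra pass is what steers training away from sharp regions of the loss landscape.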

SAMformer uniquely combines all these components in a lightweight implementation with very few hyperparameters. We display below the resulting architecture.

Results

We conduct our experiments on various multivariate time series forecasting benchmarks.

🥇 Improved performance. SAMformer outperforms its competitors in $\mathbf{7}$ out of $\mathbf{8}$ datasets by a large margin. In particular, it improves over its best competitor, TSMixer+SAM, by $\mathbf{5.25}$%, over the standalone TSMixer by $\mathbf{14.33}$%, and over the best transformer-based model, FEDformer, by $\mathbf{12.36}$%. In addition, it improves over the vanilla Transformer by $\mathbf{16.96}$%. For each dataset and horizon, SAMformer is ranked either first or second.

🚀 Computational efficiency and versatility. Unlike most of its competitors, SAMformer has a lightweight implementation with few learnable parameters, which improves computational efficiency. It significantly outperforms state-of-the-art multivariate time series forecasting models despite having fewer parameters. In addition, the same architecture is used for all datasets, whereas most other baselines require heavy hyperparameter tuning, which showcases the versatility of our approach.

📚 Qualitative benefits. We display in our paper the benefits of SAMformer in terms of smoothness of the loss landscape, robustness to the prediction horizons, and signal propagation in the attention layer.

Installation

To get started with SAMformer, clone this repository and install the required packages.

git clone https://github.com/romilbert/samformer.git
cd samformer
pip install -r requirements.txt

Make sure you have Python 3.8 or a newer version installed.

Modules

SAMformer consists of several key modules:

Usage

To launch the training and evaluation process, use the run_script.sh script with the appropriate arguments:

sh run_script.sh -m [model_name] -d [dataset_name] -s [sequence_length] -u -a

Script Arguments

Example

sh run_script.sh -m transformer -d ETTh1 -u -a
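To launch several runs from Python, the invocation can be assembled programmatically. This is a hypothetical convenience helper, not part of the repository: it only reproduces the flags shown in the usage line above (`-m`, `-d`, `-s`, `-u`, `-a`) in the same order, and passes `-u` and `-a` through unchanged without assuming what they do.

```python
import shlex
import subprocess

def build_command(model, dataset, seq_len=None, flags=("-u", "-a")):
    """Assemble a run_script.sh invocation (hypothetical helper)."""
    cmd = ["sh", "run_script.sh", "-m", model, "-d", dataset]
    if seq_len is not None:
        cmd += ["-s", str(seq_len)]  # -s [sequence_length] from the usage line
    cmd += list(flags)               # passed through unchanged
    return cmd

def launch(cmd):
    """Run the command from the repository root; raises CalledProcessError on failure."""
    return subprocess.run(cmd, check=True)

cmd = build_command("transformer", "ETTh1")
print(shlex.join(cmd))  # sh run_script.sh -m transformer -d ETTh1 -u -a
```

Passing the arguments as a list to `subprocess.run` avoids shell quoting issues; `shlex.join` (Python 3.8+) is used here only to display the equivalent command line.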

License

This project is licensed under the MIT License. See the LICENSE file for more details.

Open-source Participation

Do not hesitate to contribute to this project by submitting pull requests or issues; we would be happy to receive feedback and integrate your suggestions.

Contact

Feel free to contact Romain Ilbert (romain.ilbert@hotmail.fr) or Ambroise Odonnat (ambroiseodonnattechnologie@gmail.com) in case of questions.

Acknowledgements

We would like to express our gratitude to all the researchers and developers whose work and open-source software have contributed to the development of SAMformer. Special thanks to the authors of SAM, TSMixer, RevIN and $\sigma$Reparam for their instructive works, which have enabled our approach. We provide below a non-exhaustive list of GitHub repositories that helped with valuable code base and datasets: