XDZhelheim / STAEformer

[CIKM'23] Official code for our paper "Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting".
https://arxiv.org/abs/2308.10425

Change input_dim to 1 #3

AlexanderProchnow closed this issue 9 months ago

AlexanderProchnow commented 9 months ago

I noticed that the input_dim parameter in the STAEformer class, as well as in STAEformer.yaml, is set to 3. However, line 200 of STAEformer.py (x = x[..., : self.input_dim]) keeps every input channel up to input_dim, which means the time-of-day and day-of-week values are kept as well. The input projection is then applied to x, whereas in the paper it is applied only to the traffic volume. Shouldn't input_dim therefore be 1?
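To make the slicing concrete, here is a minimal NumPy sketch (not the repository's code) of `x[..., :input_dim]` followed by a linear input projection. It assumes an input of shape (batch, in_steps, num_nodes, channels) with channel 0 = traffic volume, 1 = time of day, 2 = day of week; the toy sizes, the random weights and the helper `project_input` are hypothetical.

```python
import numpy as np

# Hypothetical toy data; assumed channel order: 0 = traffic, 1 = time of day, 2 = day of week.
batch, in_steps, num_nodes = 2, 12, 4
rng = np.random.default_rng(0)
traffic = rng.random((batch, in_steps, num_nodes))
tod = np.tile(np.linspace(0, 1, in_steps, endpoint=False)[None, :, None], (batch, 1, num_nodes))
dow = np.full((batch, in_steps, num_nodes), 3.0)
x = np.stack([traffic, tod, dow], axis=-1)  # shape (2, 12, 4, 3)

def project_input(x, input_dim, embed_dim=24, seed=0):
    """Mimic x = x[..., :input_dim] followed by a linear input projection (hypothetical helper)."""
    x_sel = x[..., :input_dim]  # keeps only the first input_dim channels
    w = np.random.default_rng(seed).standard_normal((input_dim, embed_dim))
    b = np.zeros(embed_dim)
    return x_sel, x_sel @ w + b  # projection acts on every kept channel

for input_dim in (1, 3):
    x_sel, h = project_input(x, input_dim)
    print(input_dim, x_sel.shape, h.shape)
# 1 (2, 12, 4, 1) (2, 12, 4, 24)
# 3 (2, 12, 4, 3) (2, 12, 4, 24)
```

Under these assumptions, input_dim=1 projects only the traffic volume, while input_dim=3 also feeds the time-of-day and day-of-week values through the same projection.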

XDZhelheim commented 9 months ago

Hello, thanks for your feedback. The point you raised is valid: the description of C=1 in our paper is not accurate. Using the time indices as input features is a common approach in related works such as DCRNN, Graph WaveNet, and STID. In our code, you can change the input_dim parameter to select which features to use. By default, we use all available features (following the code of the models mentioned above). Thank you for bringing this issue to our attention; we will improve our explanation in the future.
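One way to see the trade-off is that the time indices can reach the model through two routes: as raw channels in the input projection (when input_dim=3) and through separate learned time-of-day and day-of-week embedding tables, as described for the paper's embedding layer. The NumPy sketch below is a hypothetical illustration of that idea, not the repository's implementation; the table sizes, `steps_per_day=288`, the helper `embed`, and the assumption that time of day is stored as a fraction of a day are all assumptions.

```python
import numpy as np

# Hypothetical sizes; tod/dow embedding tables mirror the idea of the paper's embedding layer.
steps_per_day, tod_dim, dow_dim, input_embed_dim = 288, 8, 8, 8
rng = np.random.default_rng(0)
tod_table = rng.standard_normal((steps_per_day, tod_dim))
dow_table = rng.standard_normal((7, dow_dim))

def embed(x, input_dim):
    """Concatenate projected raw channels with tod/dow embedding lookups (hypothetical helper)."""
    w = rng.standard_normal((input_dim, input_embed_dim))
    proj = x[..., :input_dim] @ w  # raw channels, including tod/dow if input_dim == 3
    # Assumption: time of day is a fraction of a day; truncate to an index and clip to range.
    tod_idx = np.clip((x[..., 1] * steps_per_day).astype(int), 0, steps_per_day - 1)
    dow_idx = x[..., 2].astype(int)
    return np.concatenate([proj, tod_table[tod_idx], dow_table[dow_idx]], axis=-1)

# Toy input with assumed channel order: traffic, time of day, day of week.
batch, in_steps, num_nodes = 2, 12, 4
traffic = rng.random((batch, in_steps, num_nodes))
tod = np.broadcast_to((np.arange(in_steps) / steps_per_day)[None, :, None], (batch, in_steps, num_nodes))
dow = np.full((batch, in_steps, num_nodes), 3.0)
x = np.stack([traffic, tod, dow], axis=-1)

for input_dim in (1, 3):
    print(input_dim, embed(x, input_dim).shape)
# 1 (2, 12, 4, 24)
# 3 (2, 12, 4, 24)
```

In this sketch the output width is the same either way; with input_dim=3 the time information simply enters twice, once as raw values and once through the lookup tables.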