unit8co / darts

A python library for user-friendly forecasting and anomaly detection on time series.
https://unit8co.github.io/darts/
Apache License 2.0

Tips on speeding up training of the TFT #2344

Open chododom opened 4 months ago

chododom commented 4 months ago

Hi, I am trying to compare the forecasts of a basic LSTM, a DeepAR model and a TemporalFusionTransformer (TFT) on several sets of multivariate series.

I have around 4,000 different TimeSeries objects, which together amount to a few million data points.
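For a sense of scale, the number of training samples per epoch grows with the total length of the series rather than with the model size. Below is a minimal sketch. It assumes each training sample is one sliding window of `input_chunk_length + output_chunk_length` consecutive steps and that there is no per-series cap on samples. The series lengths and chunk lengths are hypothetical values, not the ones used in this issue.

```python
def windows_per_series(length: int, input_chunk_length: int, output_chunk_length: int) -> int:
    """Number of sliding windows (input + output span) that fit in one series."""
    span = input_chunk_length + output_chunk_length
    return max(0, length - span + 1)


def total_windows(lengths, input_chunk_length: int, output_chunk_length: int) -> int:
    """Sum of windows over all series; assumed to be roughly one training sample each."""
    return sum(
        windows_per_series(n, input_chunk_length, output_chunk_length) for n in lengths
    )


# Hypothetical data: 4000 series of 750 steps each (about 3 million points in total).
lengths = [750] * 4000
print(total_windows(lengths, input_chunk_length=24, output_chunk_length=12))
```

Under these assumptions the epoch contains millions of samples regardless of `hidden_size`, so even a small model can take a long time per epoch.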

For the LSTM and DeepAR, the training times are acceptable. With the TFT, however, I am seeing something really weird. With 2 attention heads, 1 RNN layer and a hidden dimension of 8, an epoch takes about an hour. If I increase the hidden dimension from 8 to 10, an epoch is suddenly estimated to take 11 hours.

How is this possible? According to the model summary, the model only has about 40k parameters.

I am also using torch.set_float32_matmul_precision('medium') to try to speed up the training, but I'm having absolutely no luck.

Any explanation of the complexity or tips on improving the computation speed would be very welcome. Thank you!

dennisbader commented 4 months ago

Hi @chododom, TFTModel is a transformer model, so it is more complex and less efficient than the other models. From my checks, increasing the hidden size from 16 to 32 (a factor of 2) increased the number of trainable parameters by a factor of 4. So a factor-of-11 increase in training time when going from a hidden size of 8 to 10 (a factor of 1.25) does indeed sound strange (though I'm not saying yet that it is a bug).
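The scaling argument above can be made concrete with a rough estimate. This minimal sketch assumes the dominant cost comes from dense layers that map a hidden vector of size `h` to another of size `h`, giving `h*h` weights plus `h` biases. That is a deliberate simplification and not the actual TFTModel architecture.

```python
def dense_params(hidden_size: int) -> int:
    """Weights plus biases of one hidden_size -> hidden_size linear layer."""
    return hidden_size * hidden_size + hidden_size


def growth_factor(old_size: int, new_size: int) -> float:
    """Relative growth in the parameter count of such a layer when the hidden size changes."""
    return dense_params(new_size) / dense_params(old_size)


print(f"16 -> 32: x{growth_factor(16, 32):.2f}")  # roughly quadratic, close to 4x
print(f"8 -> 10:  x{growth_factor(8, 10):.2f}")   # only about 1.5x
```

Under this assumption, going from 8 to 10 should make each such layer roughly 1.5x larger, and the per-sample matrix-multiply cost grows by about the same amount. That is far below the reported 11x slowdown, which is why the jump looks suspicious.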

We would have to perform an in-depth analysis and profile the model to see whether this is normal. Currently, we don't have much capacity for this, as we're working on higher-priority tasks. So any help from the community would be greatly appreciated :)
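For anyone who wants to help, one way to start is to profile a single training epoch at both hidden sizes and compare where the time goes. This minimal sketch uses Python's built-in `cProfile` and `pstats` modules. `train_one_epoch` is a hypothetical stand-in for whatever runs one epoch, such as a fit call limited to one epoch. Its body here is only placeholder work.

```python
import cProfile
import io
import pstats


def train_one_epoch(hidden_size: int) -> float:
    """Hypothetical placeholder for one training epoch; replace with the real training call."""
    total = 0.0
    for i in range(20_000):
        total += (i % hidden_size) * 0.5
    return total


def profile_epoch(hidden_size: int, top: int = 10) -> str:
    """Profile one epoch and return the `top` entries sorted by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    train_one_epoch(hidden_size)
    profiler.disable()
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(top)
    return stream.getvalue()


for size in (8, 10):
    print(f"--- hidden_size={size} ---")
    print(profile_epoch(size))
```

`cProfile` only sees Python-level calls. To inspect GPU kernels, PyTorch's own `torch.profiler` gives a more detailed view.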

Also, we have some additional recommendations for model performance in our user guide for torch models.