timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
Topics: Time series, Timeseries, Deep Learning, Machine Learning, Python, Pytorch, fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
5.1k stars, 639 forks

Transformer Model and TST are not converging. #634

Open · HasnainKhanNiazi opened this issue 1 year ago

HasnainKhanNiazi commented 1 year ago

I am working on a regression problem and training both TransformerModel and TST on it. My dataset and model configurations are below.

Dataset (same for both models):
- Window length = 100
- Features per time step = 94
- batch_tfms=TSStandardize(by_var=True), as in the original paper
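Per-variable standardization gives each of the 94 features its own mean and standard deviation, so features on very different scales end up comparable. The NumPy sketch below only illustrates that idea; it is not tsai's implementation. It assumes tsai's (samples, variables, steps) layout, pools the statistics over the sample and time axes, and runs on a synthetic batch.

```python
import numpy as np

def standardize_by_var(xb: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardize each variable with one mean/std per variable.

    xb is assumed to be shaped (samples, variables, steps). Statistics are
    pooled over axes 0 and 2, so each variable ends up with ~0 mean, ~1 std.
    """
    mean = xb.mean(axis=(0, 2), keepdims=True)  # shape (1, variables, 1)
    std = xb.std(axis=(0, 2), keepdims=True)    # shape (1, variables, 1)
    return (xb - mean) / (std + eps)

# Synthetic batch: 32 windows, 94 features, 100 steps, features on very different scales.
rng = np.random.default_rng(0)
scales = rng.uniform(0.1, 100.0, size=(1, 94, 1))
xb = rng.normal(size=(32, 94, 100)) * scales + 50.0
xs = standardize_by_var(xb)
print(xs.shape)  # (32, 94, 100)
```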

Model config (TransformerModel): d_model=768, n_head=12, n_layers=12, loss=MSELossFlat

Model config (TST): n_layers=12, d_model=768, n_heads=12
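To put the 12-layer, 768-wide configuration in perspective, here is a back-of-the-envelope parameter count for the encoder stack. It assumes a post-norm layer laid out like torch.nn.TransformerEncoderLayer (packed Q/K/V projection, output projection, two feed-forward linears, two LayerNorms); whether TST and TransformerModel match that layout exactly is an assumption. The feed-forward width d_ff is not given in the post, so the 3072 used below (BERT-base's value) is hypothetical. The number of heads does not change the total, because the heads split d_model between them.

```python
def encoder_layer_params(d_model: int, d_ff: int) -> int:
    """Parameters in one post-norm encoder layer laid out like torch.nn.TransformerEncoderLayer."""
    attn = 4 * d_model * d_model + 4 * d_model  # packed Q/K/V projection + output projection, with biases
    ffn = 2 * d_model * d_ff + d_ff + d_model   # two linear layers, with biases
    norms = 4 * d_model                         # two LayerNorms (weight + bias each)
    return attn + ffn + norms

def encoder_stack_params(d_model: int, d_ff: int, n_layers: int) -> int:
    # Excludes the input projection (94 -> d_model), positional encodings and the regression head.
    return n_layers * encoder_layer_params(d_model, d_ff)

# d_ff is NOT stated in the post; 3072 is a hypothetical choice borrowed from BERT-base.
print(encoder_stack_params(d_model=768, d_ff=3072, n_layers=12))  # 85054464
```

Under these assumptions the encoder layers alone hold about 85 million weights, before counting the input projection and the head.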

TransformerModel takes around 3 hours per epoch, and training is now on its 34th epoch. For both TransformerModel and TST, the lowest validation loss came at the 9th epoch; since then, neither model is converging.
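When the best validation loss arrives early (epoch 9 here) and later epochs never beat it, a common remedy is to keep the best checkpoint and stop after a patience window; fastai provides SaveModelCallback and EarlyStoppingCallback for this. The plain-Python sketch below only illustrates that logic, and the loss values in it are made up.

```python
def best_epoch(valid_losses):
    """Return the 1-based epoch with the lowest validation loss."""
    return min(range(len(valid_losses)), key=valid_losses.__getitem__) + 1

def should_stop(valid_losses, patience=3, min_delta=0.0):
    """True once `patience` epochs in a row fail to beat the best loss by more than min_delta."""
    best = float("inf")
    wait = 0
    for loss in valid_losses:
        if loss < best - min_delta:
            best, wait = loss, 0  # new best: reset the patience counter
        else:
            wait += 1
            if wait >= patience:
                return True
    return False

# Made-up validation losses that bottom out at epoch 9.
valid_losses = [5.0, 4.2, 3.9, 3.7, 3.6, 3.55, 3.5, 3.45, 3.4, 3.6, 3.8, 3.7]
print(best_epoch(valid_losses))              # 9
print(should_stop(valid_losses, patience=3))  # True
```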

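My dataset looks like this:

| A | B | C | D | E | F | G | H |
|---|---|---|---|---|---|---|---|
| 34 | 19.5 | 19.5 | 1 | 0.1 | 0 | -35.7742 | -2.25 |
| 34 | 19.5 | 19.5 | 1 | -0.1 | 0 | -39.1072 | -2.25 |
| 34 | 19.5 | 19.5 | 1 | 0 | 0 | -38.885 | -2.5 |
| 34 | 19.5 | 19.5 | 1 | 1 | 0 | -38.6628 | -2.5 |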

For obvious reasons, I can't post the whole dataset. Any help would be appreciated. Thanks!
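A window length of 100 over 94 features means each training sample is a 94 x 100 slice. Assuming the rows of the table are consecutive time steps and the columns are features, a NumPy sketch of that windowing looks like this. tsai ships its own SlidingWindow utility, which this sketch does not reproduce, and target extraction is left out.

```python
import numpy as np

def make_windows(data: np.ndarray, window_len: int = 100, stride: int = 1) -> np.ndarray:
    """Cut a (time_steps, features) array into windows shaped (samples, features, window_len)."""
    n_steps = data.shape[0]
    if n_steps < window_len:
        raise ValueError("need at least window_len time steps")
    starts = range(0, n_steps - window_len + 1, stride)
    windows = np.stack([data[s:s + window_len] for s in starts])  # (samples, window_len, features)
    return windows.transpose(0, 2, 1)                              # (samples, features, window_len)

# Synthetic stand-in for the real data: 1,000 time steps x 94 features.
data = np.arange(1000 * 94, dtype=np.float32).reshape(1000, 94)
X = make_windows(data, window_len=100)
print(X.shape)  # (901, 94, 100)
```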

oguiza commented 1 year ago

Hi @HasnainKhanNiazi, Here are a few comments:

> TransformerModel is taking around 3 hours for one epoch and right now, the 34th epoch is in training but the lowest validation loss that I got for TransformerModel and TST was at the 9th epoch but after that, both models are converging.

Are you using a GPU? That is a very long time per epoch unless your dataset is huge. Also, do you mean converging or diverging? From what you describe, your models may be overfitting (your key metric, MSE, starts growing after a while). For overfitting, Jeremy Howard (fastai) recommends the following steps, in order:

HasnainKhanNiazi commented 1 year ago

Hi @oguiza, thanks for your insights. Yes, I am training on a GPU (Nvidia A100); one epoch takes 3 hours because the dataset is really huge. I don't think the model is overfitting: the training loss is still very high, whereas an overfitting model would have a low training loss.
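The argument that a high training loss rules out overfitting can be made quantitative by comparing the training MSE with the MSE of a trivial model that always predicts the mean target. The NumPy sketch below does that; the 0.9 ratio is a hypothetical rule of thumb, not something from the thread.

```python
import numpy as np

def mean_baseline_mse(y_train: np.ndarray, y_eval: np.ndarray) -> float:
    """MSE of a model that always predicts the training-set mean of the target."""
    return float(np.mean((y_eval - y_train.mean()) ** 2))

def looks_like_underfitting(train_mse: float, baseline: float, ratio: float = 0.9) -> bool:
    """Hypothetical rule of thumb: a training MSE close to the constant baseline means
    the model has learned little, i.e. underfitting rather than overfitting."""
    return train_mse >= ratio * baseline

y_train = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(mean_baseline_mse(y_train, y_train))  # 2.0
```

If the training loss sits near this baseline, the model is effectively predicting the mean, which matches "not learning anything" better than overfitting.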

I will definitely change the model architecture. I was trying to recreate BERT for this regression problem, since BERT-base uses the same configuration I am using (12 layers, d_model=768, 12 attention heads).

I am attaching an image of the training progress; it may help pin down the core problem.

[Image: Screenshot from 2022-12-05 11-55-17, training progress]

EDIT: I am also using MetaDataSet because my data is spread across multiple files; len(mdset) is 4083010.
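Some quick arithmetic shows what "3 hours per epoch" means for len(mdset) = 4083010. The batch size is not stated in the thread, so the 64 below is hypothetical; the point is how epoch time scales with dataset size, and why a small subset iterates much faster.

```python
import math

def batches_per_epoch(n_samples: int, batch_size: int) -> int:
    """Number of batches needed to see every sample once (last batch may be partial)."""
    return math.ceil(n_samples / batch_size)

n_samples = 4_083_010        # len(mdset) from the thread
batch_size = 64              # hypothetical: not stated in the thread
epoch_seconds = 3 * 60 * 60  # "around 3 hours" per epoch

n_batches = batches_per_epoch(n_samples, batch_size)
print(n_batches)                             # 63798
print(round(epoch_seconds / n_batches, 3))   # 0.169 seconds per batch
print(round(n_samples / epoch_seconds))      # 378 windows per second
```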

oguiza commented 1 year ago

Hi @HasnainKhanNiazi, Looking at the losses, the model is not learning anything. I'd recommend training on a small subset of the data first. That way you can run many iterations quickly until you see the model start learning, and then scale up. Judging by how large the loss is, the issue seems related to how you are scaling the data.
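Both suggestions can be sketched in a few lines of NumPy. The first helper draws a random subset of sample indices for quick experiments. The second standardizes the regression target with train-split statistics: the post only mentions input standardization (TSStandardize), so whether an unscaled target is the culprit here is an assumption worth checking, not a conclusion from the thread. TargetScaler is a hypothetical helper, not a tsai class.

```python
import numpy as np

def random_subset(n_total: int, n_keep: int, seed: int = 0) -> np.ndarray:
    """Sorted random sample of indices for a quick small-scale experiment."""
    rng = np.random.default_rng(seed)
    return np.sort(rng.choice(n_total, size=n_keep, replace=False))

class TargetScaler:
    """Hypothetical helper: standardize a regression target with train-split
    statistics, and invert the scaling on predictions."""
    def fit(self, y):
        self.mean_ = float(np.mean(y))
        self.std_ = float(np.std(y)) or 1.0  # guard against a constant target
        return self

    def transform(self, y):
        return (np.asarray(y) - self.mean_) / self.std_

    def inverse_transform(self, y_scaled):
        return np.asarray(y_scaled) * self.std_ + self.mean_

idx = random_subset(4_083_010, 20_000)  # e.g. ~0.5% of len(mdset)
print(idx[:5])
```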

HasnainKhanNiazi commented 1 year ago

Thanks @oguiza, I will train on smaller chunks. I'll keep this issue open for now, close it once I reach a conclusion, and post an update here. Thanks

HasnainKhanNiazi commented 1 year ago

Hi @oguiza, I have been running some experiments with transformers and have implemented a few basic architectures:

  1. CNN + Vanilla Encoders + CNN
  2. Vanilla Encoders + CNN
  3. Vanilla Encoders + MLP

All of these models learn, and their validation loss decreases. But when I use the same data with the TST and TSTPlus architectures, the models don't learn anything. I am not sure what could be wrong, since the data preprocessing is identical in both cases.
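For comparison with TST, here is a minimal PyTorch sketch of what a "Vanilla Encoders + MLP" regressor could look like. The author's actual code is not in the thread, so EncoderMLP is a hypothetical name and every layer size, the learnable positional embedding, and the mean pooling over time are assumptions.

```python
import torch
from torch import nn

class EncoderMLP(nn.Module):
    """Hypothetical 'Vanilla Encoders + MLP' regressor for inputs shaped (batch, features, steps)."""
    def __init__(self, c_in: int, seq_len: int, d_model: int = 64, n_heads: int = 4,
                 n_layers: int = 2, d_ff: int = 128, dropout: float = 0.1):
        super().__init__()
        self.proj = nn.Linear(c_in, d_model)                       # per-step feature projection
        self.pos = nn.Parameter(torch.zeros(1, seq_len, d_model))  # learnable positional embedding
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=d_ff,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.transpose(1, 2)                          # (batch, steps, features)
        h = self.encoder(self.proj(x) + self.pos)      # (batch, steps, d_model)
        return self.head(h.mean(dim=1)).squeeze(-1)   # mean-pool over time -> one value per sample

model = EncoderMLP(c_in=94, seq_len=100)
print(model(torch.randn(8, 94, 100)).shape)  # torch.Size([8])
```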

oguiza commented 1 year ago

Hi @HasnainKhanNiazi , Sorry for the late reply. Could you please paste a code snippet to reproduce the issue? I have not been able to reproduce it. Have you tried using your approach with any of the datasets available in tsai?
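Alongside a reproducible snippet, a standard sanity check for "the model doesn't learn anything" is to train on one fixed batch and confirm the loss collapses. If it doesn't, the problem lies in the model wiring, the targets, or the loss rather than in generalization. The sketch below assumes a plain PyTorch loop (not fastai/tsai's Learner) and uses a tiny synthetic batch; any model mapping the batch to the target shape can be passed in.

```python
import torch
from torch import nn

def overfit_one_batch(model: nn.Module, xb: torch.Tensor, yb: torch.Tensor,
                      steps: int = 200, lr: float = 1e-3) -> tuple[float, float]:
    """Train repeatedly on a single fixed batch; return (first_loss, last_loss)."""
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    losses = []
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses[0], losses[-1]

# Tiny synthetic example: 8 samples, 3 features, 10 steps.
torch.manual_seed(0)
xb, yb = torch.randn(8, 3, 10), torch.randn(8, 1)
tiny = nn.Sequential(nn.Flatten(), nn.Linear(30, 1))
first, last = overfit_one_batch(tiny, xb, yb, steps=500, lr=1e-2)
print(first, last)
```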