WenjieDu / PyPOTS

A Python toolkit/library for reality-centric machine/deep learning and data mining on partially-observed time series, including SOTA neural network models for scientific analysis tasks of imputation, classification, clustering, forecasting, & anomaly detection on incomplete industrial (irregularly-sampled) multivariate TS with NaN missing values
https://pypots.com
BSD 3-Clause "New" or "Revised" License

[Feature request] Is it possible to "warm-up" the transformer? #25

Closed: b2jia closed this issue 1 year ago

b2jia commented 1 year ago

Thank you for creating this wonderful resource! This is an amazing and useful tool!

Regarding SAITS, is it possible to pass a learning-rate scheduler, rather than a fixed learning rate, when training the transformer?

I ask this because I compared the outputs of training for 100 epochs vs. 1000 epochs. The loss continues to decrease, but the error on held-out time points does not change between 100 and 1000 epochs. Strangely, the prediction (after both 100 and 1000 epochs) is less accurate than linear interpolation! I wondered if this is because the transformer has too many parameters and needs some help learning initially.

WenjieDu commented 1 year ago

Hi there,

Thank you so much for your attention to PyPOTS! If you find PyPOTS helpful to your work, please star⭐️ this repository. Your star is your recognition, which can help more people notice PyPOTS and grow the PyPOTS community. It matters and is definitely a kind of contribution.

I have received your message and will respond ASAP. Thank you for your patience! 😃

Best,
Wenjie

MaciejSkrabski commented 1 year ago

While we wait for Wenjie's feedback, please allow me to chime in.

I ask this because I compared the outputs of training for 100 epochs vs. 1000 epochs. The loss continues to decrease, but the error on held-out time points does not change between 100 and 1000 epochs.

From my (very) limited understanding of transformers, they learn blazingly fast! In my case, where an LSTM needed thousands of epochs to converge, around 40 epochs were sufficient for a transformer. Consider monitoring the validation loss more often during the initial epochs. Also, keep calm and lower the learning rate!

By the way, did you remember to simulate missing data in the training data? The SAITS paper describes that the model needs to see two kinds of data, from which two kinds of errors are calculated: one for missing-data reconstruction, the other for visible-data approximation. If the model does not train on the actual task it is supposed to solve, it cannot learn to solve it.

Best of luck!

b2jia commented 1 year ago

This is great insight @MaciejSkrabski !

By the way, did you remember to simulate missing data in the training data?

I did, but I'm puzzled. I introduce artificially missing values (5%) into my already-incomplete data (50% missing values). I assume that, during training, SAITS takes my artificial+incomplete input and then does something similar to MCAR masking: it introduces its own artificially missing values (@WenjieDu, what fraction is this, can you confirm?), imputes them, and evaluates the imputation loss (MIT). Otherwise, how does it train?

I currently use my artificially missing values as a "test" dataset to evaluate the final imputation loss, e.g.:

imputation = saits.impute(X)  # impute the originally-missing values and artificially-missing values
mae = cal_mae(imputation, X_intact, indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)

Did I misunderstand the API - do I have to explicitly pass in the artificially missing value mask as input as well? I currently don't see an option to do so.
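
For context, my full pipeline currently looks roughly like this (following the PyPOTS quick-start example; the mcar / masked_fill helpers are what I'm assuming from the docs, and the hyperparameters and shapes are placeholders for my data):

import numpy as np
from pypots.data import mcar, masked_fill
from pypots.imputation import SAITS
from pypots.utils.metrics import cal_mae

# X: (n_samples, n_steps, n_features) with NaNs for the originally-missing values
X_intact, X, missing_mask, indicating_mask = mcar(X, 0.05)  # hold out 5% of observed values as ground truth
X = masked_fill(X, 1 - missing_mask, np.nan)                # put NaNs back where values were held out

saits = SAITS(n_steps=X.shape[1], n_features=X.shape[2], n_layers=2, d_model=256,
              d_inner=128, n_head=4, d_k=64, d_v=64, dropout=0.1, epochs=100)
saits.fit(X)                                          # train on data that still contains NaNs
imputation = saits.impute(X)                          # impute originally- and artificially-missing values
mae = cal_mae(imputation, X_intact, indicating_mask)  # error only on the artificially-held-out values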

b2jia commented 1 year ago

After revisiting the SAITS paper, I'm wondering: what is the intuition behind MIT_weight and ORT_weight? I find SAITS is able to reconstruct data with near-perfect accuracy, but the imputation (at least on my dataset) is only slightly better than linear interpolation. Is it a matter of weighting the imputation more, i.e. MIT_weight=10, ORT_weight=1, or a matter of model architecture (more heads? a deeper network?)

WenjieDu commented 1 year ago

Thank you for creating this wonderful resource! This is an amazing and useful tool!

Regarding SAITS, is it possible to pass a learning-rate scheduler, rather than a fixed learning rate, when training the transformer?

I ask this because I compared the outputs of training for 100 epochs vs. 1000 epochs. The loss continues to decrease, but the error on held-out time points does not change between 100 and 1000 epochs. Strangely, the prediction (after both 100 and 1000 epochs) is less accurate than linear interpolation! I wondered if this is because the transformer has too many parameters and needs some help learning initially.

Hi Bojing, thank you for raising this issue, and for your patience. And many thanks for your timely help, Maciej @MaciejSkrabski! I'm sorry for my delayed response.

Actually, I tried warm-up for the Transformer and SAITS, but I didn't obtain any notable improvement. I think this trick works for transformers on very large datasets, like NLP corpora, but I may be wrong. Of course, you can write a scheduler to give it a try in your experiment settings. For a quick start, you can use the ready-made schedulers from the Transformers library. Please let me know if you make any new discoveries.
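
For reference, here is a minimal sketch of the warm-up schedule from the original Transformer paper, written with PyTorch's LambdaLR. This assumes you run your own training loop (e.g. with the code in the SAITS repo), since PyPOTS does not expose the optimizer yet; the model below is only a placeholder. HuggingFace Transformers offers similar helpers such as get_linear_schedule_with_warmup.

import torch

model = torch.nn.Linear(8, 8)  # toy stand-in; replace with your own imputation model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
warmup_steps = 4000  # tune for your dataset size

def noam_lambda(step, warmup=warmup_steps):
    # linear warm-up for `warmup` steps, then inverse-sqrt decay;
    # the multiplier peaks at 1.0 when step == warmup
    step = max(step, 1)
    return (warmup ** 0.5) * min(step ** -0.5, step * warmup ** -1.5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)
# inside the training loop, call scheduler.step() after each optimizer.step()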

WenjieDu commented 1 year ago

This is great insight @MaciejSkrabski !

By the way, did you remember to simulate missing data in the training data?

I did, but I'm puzzled. I introduce artificially missing values (5%) into my already-incomplete data (50% missing values). I assume that, during training, SAITS takes my artificial+incomplete input and then does something similar to MCAR masking: it introduces its own artificially missing values (@WenjieDu, what fraction is this, can you confirm?), imputes them, and evaluates the imputation loss (MIT). Otherwise, how does it train?

I currently use my artificially missing values as a "test" dataset to evaluate the final imputation loss, e.g.:

imputation = saits.impute(X)  # impute the originally-missing values and artificially-missing values
mae = cal_mae(imputation, X_intact, indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)

Did I misunderstand the API - do I have to explicitly pass in the artificially missing value mask as input as well? I currently don't see an option to do so.

The default artificially-missing rate applied for MIT is 20% (check out the code here). I think your understanding of the API is right. Considering PyPOTS still lacks detailed documentation and is under active development, I recommend you try my code in the SAITS repo to impute your data.
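
To sketch what happens during training (illustrative only, not the actual PyPOTS implementation): in each sample, a fraction of the observed values is randomly held out, and the MIT loss is computed only on those held-out values. Something like:

import numpy as np

def artificially_mask(X, rate=0.2, seed=0):
    # randomly hide `rate` of the observed entries; the hidden entries become
    # the ground truth that the MIT (masked imputation task) loss is computed on
    rng = np.random.default_rng(seed)
    observed = ~np.isnan(X)
    hidden = observed & (rng.random(X.shape) < rate)
    X_corrupted = X.copy()
    X_corrupted[hidden] = np.nan
    return X_corrupted, hidden  # `hidden` plays the role of the indicating mask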


After revisiting the SAITS paper, I'm wondering: what is the intuition behind MIT_weight and ORT_weight? I find SAITS is able to reconstruct data with near-perfect accuracy, but the imputation (at least on my dataset) is only slightly better than linear interpolation. Is it a matter of weighting the imputation more, i.e. MIT_weight=10, ORT_weight=1, or a matter of model architecture (more heads? a deeper network?)

They are the loss weights of the corresponding tasks. Weighting different losses is a common technique in multi-task learning. Giving a task a higher weight makes the model pay more attention to it, because errors on that task are penalized more heavily. Hope this helps.
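
To make it concrete, the joint objective is just a weighted sum of the two task losses (a sketch of the idea, not the exact PyPOTS code; the numbers are made up):

ORT_weight, MIT_weight = 1.0, 10.0         # hypothetical weights favouring imputation
ORT_loss, MIT_loss = 0.02, 0.15            # example per-batch loss values
loss = ORT_weight * ORT_loss + MIT_weight * MIT_loss  # = 1.52, dominated by the MIT term

If you want to try re-weighting, both weights can be set when constructing the model, e.g. MIT_weight=10, ORT_weight=1 as you suggested.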

b2jia commented 1 year ago

Thanks for the response @WenjieDu, I will give the other repo a try. By the way, is it possible to inject temporal information into SAITS? For instance, the values at missing time points are unknown, but the time points themselves are known (and sometimes the time points are unevenly spaced).

WenjieDu commented 1 year ago

Like most imputation algorithms, the original SAITS assumes the input is sampled at even time intervals. However, you can add the sampling timestamp as an additional feature of the input, or embed the timing information into the positional encoding. But if the features in your dataset are sampled irregularly (e.g. not all features are collected in one sampling operation), this may make your data even more sparse.
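
For the first option, a minimal sketch (this is not a built-in PyPOTS feature; add_time_feature and the array shapes are purely illustrative):

import numpy as np

def add_time_feature(X, timestamps):
    # X: (n_samples, n_steps, n_features), timestamps: (n_samples, n_steps)
    # append the gap since the previous observation as one extra feature channel
    deltas = np.diff(timestamps, axis=1, prepend=timestamps[:, :1])
    return np.concatenate([X, deltas[..., None]], axis=-1)  # (n_samples, n_steps, n_features + 1)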

It's an interesting question. Could you please share more details about your data and scenario? I'd love to know what kind of application needs such a feature, and then I can see what I can do to help further.

b2jia commented 1 year ago

Thank you again for responding! I will close this issue; I've emailed you directly.