sktime / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License

Applying transfer learning with pytorch-forecasting #259

Closed by ghost 3 years ago

ghost commented 3 years ago

Hi,

I have a project where I think I can reuse a trained temporal_fusion_transformer model for a different target. I am a bit puzzled about which parts of the model weights I should freeze and which ones I should retrain.

I am guessing that I could start by unfreezing only the last (linear) output layer and see how that works out, trial-and-error style. But I was wondering if there are specific guidelines to follow for this type of neural network architecture.
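A minimal sketch of that starting point, written against plain PyTorch rather than the library itself. `ToyForecaster` and `freeze_all_but` are hypothetical names made up for illustration, and the assumption that the final linear layer is exposed under a name starting with `output_layer.` should be checked against `model.named_parameters()` on the actual trained model.

```python
import torch
from torch import nn


class ToyForecaster(nn.Module):
    """Stand-in for a pretrained network; NOT the real temporal fusion transformer."""

    def __init__(self, n_features: int = 4, hidden_size: int = 8):
        super().__init__()
        self.encoder = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, _ = self.encoder(x)
        return self.output_layer(hidden[:, -1])


def freeze_all_but(model: nn.Module, trainable_prefixes: tuple) -> list:
    """Freeze every parameter whose name does not start with one of the prefixes.

    Returns the names of the parameters that remain trainable.
    """
    trainable = []
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)
        if param.requires_grad:
            trainable.append(name)
    return trainable


model = ToyForecaster()
# Assumption: the layer to retrain is registered as "output_layer".
trainable_names = freeze_all_but(model, ("output_layer.",))

# Hand only the unfrozen parameters to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
```

Printing `trainable_names` before training is a cheap sanity check that only the intended layer will be updated.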

I realize this is not really an appropriate question for a GitHub issue; however, at the moment it seems better to post it here than on Stack Overflow. @jdb78 do you agree?

Thanks in any case!

Best regards,

Tomas

jdb78 commented 3 years ago

Agree. Generally speaking, you also want to leave the normalization layers unfrozen. There is even a chance that you will want to unfreeze the VariableSelectionNetwork and remove the embeddings and the linear pre-scaling layers for continuous variables. This would make a great PR. I guess you could even turn it into a research paper, given that transfer learning has been demonstrated to a large extent in time series forecasting.
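The suggestion about normalization layers can be sketched in plain PyTorch as follows. This is an illustration under assumptions, not the library's API: `ToySelection`, `ToyForecaster`, and `unfreeze_modules` are hypothetical names, and the toy model only mimics the idea of a selection block followed by an encoder. On a real model, the class name to pass (for example `"VariableSelectionNetwork"`) and the normalization types in use should be confirmed by inspecting `model.named_modules()`.

```python
import torch
from torch import nn

# Normalization layer types to keep trainable (assumed set; extend as needed).
NORM_TYPES = (nn.LayerNorm, nn.BatchNorm1d)


class ToySelection(nn.Module):
    """Hypothetical stand-in for a variable selection block."""

    def __init__(self, n_features: int):
        super().__init__()
        self.weights = nn.Linear(n_features, n_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.softmax(self.weights(x), dim=-1)


class ToyForecaster(nn.Module):
    """Stand-in for a pretrained network; NOT the real temporal fusion transformer."""

    def __init__(self, n_features: int = 4, hidden_size: int = 8):
        super().__init__()
        self.selection = ToySelection(n_features)
        self.encoder = nn.LSTM(n_features, hidden_size, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, _ = self.encoder(self.selection(x))
        return self.output_layer(self.norm(hidden[:, -1]))


def unfreeze_modules(model: nn.Module, module_types=NORM_TYPES, class_names=()) -> list:
    """Re-enable gradients for modules matched by type or by class name."""
    unfrozen = []
    for name, module in model.named_modules():
        if isinstance(module, module_types) or type(module).__name__ in class_names:
            for param in module.parameters():
                param.requires_grad = True
            unfrozen.append(name)
    return unfrozen


model = ToyForecaster()
model.requires_grad_(False)              # start from a fully frozen network
model.output_layer.requires_grad_(True)  # retrain the output layer
unfrozen = unfreeze_modules(model, class_names=("ToySelection",))
```

Matching on the class name avoids importing the sub-module class directly; matching on `isinstance` also catches normalization layers nested inside other blocks.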

ghost commented 3 years ago

Hi Jan, thanks for your response. If I end up with something worthwhile, I will let you know. Thanks a lot! Best regards, Tomas