jdb78 / pytorch-forecasting

Time series forecasting with PyTorch
https://pytorch-forecasting.readthedocs.io/
MIT License
3.87k stars 611 forks

Weights init on N-Beats? #439

Closed (pnmartinez closed this issue 3 years ago)

pnmartinez commented 3 years ago

Hi @jdb78 ,

I am trying to customize your N-Beats implementation to use the SELU() activation function instead of ReLU().

It's easy to replace it and use PyTorch's implementation, but SELU() "demands lecun_normal for weight initialization, and if dropout wants to be applied, one should use AlphaDropout" (source).

I searched your codebase, but I can only see weight initialization in the TFT, not even in the BaseModel that NBeats inherits from.

Can you give me any pointers on how N-Beats initializes its weights? Maybe the specific torch class?

Cheers and keep up the good work!

pnmartinez commented 3 years ago

Hi,

I think I cracked it! :smile:

I was not familiar with weight initialization in PyTorch. Now I know it's messy, and that the API will be revamped in that regard in version 1.9.
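For anyone else wondering where the initial weights come from when a model has no explicit init code, here is a minimal sketch. It assumes the N-Beats blocks are built from plain `torch.nn.Linear` layers (an assumption, not something confirmed in this thread). Such a layer initializes itself in its own `reset_parameters()` method, and with the defaults both weights and biases end up inside ±1/sqrt(fan_in):

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

# A plain linear layer, standing in for one fully connected layer of an
# N-Beats block (assumption: the blocks use nn.Linear with no explicit init).
layer = nn.Linear(in_features=64, out_features=32)

# nn.Linear initializes itself in reset_parameters(), called from __init__.
# With the defaults, weights and biases are drawn uniformly from
# [-1/sqrt(fan_in), 1/sqrt(fan_in)].
fan_in = layer.weight.shape[1]
bound = 1.0 / math.sqrt(fan_in)

max_abs_weight = layer.weight.detach().abs().max().item()
max_abs_bias = layer.bias.detach().abs().max().item()
print(f"bound={bound:.4f} max|w|={max_abs_weight:.4f} max|b|={max_abs_bias:.4f}")
```

So a search for `nn.init` calls can come up empty even though every layer is initialized.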

However, I think I implemented everything SELU requires, i.e. 3 changes:

If anyone wants to double-check my implementation, or finds it useful, here it is: https://gist.github.com/pnmartinez/fef1f488497fa85a2cc1626af2a5b4bd
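For readers who cannot open the gist, the requirements quoted earlier in the thread (SELU activation, lecun_normal initialization, AlphaDropout) can be sketched roughly as follows. This is not the gist's code and not the library's N-Beats implementation: `lecun_normal_` and `make_selu_stack` are hypothetical names, and the stack is only a stand-in for the fully connected part of a block. PyTorch has no built-in function called `lecun_normal`, so the sketch draws the weights with `nn.init.normal_` and a standard deviation of sqrt(1 / fan_in).

```python
import math

import torch
from torch import nn


def lecun_normal_(module: nn.Module) -> None:
    """Hypothetical helper: LeCun-normal init for nn.Linear layers, in place."""
    if isinstance(module, nn.Linear):
        fan_in = module.weight.shape[1]
        # LeCun normal: zero mean, std = sqrt(1 / fan_in).
        nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(1.0 / fan_in))
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def make_selu_stack(in_features: int, hidden: int, n_layers: int, dropout: float) -> nn.Sequential:
    """Hypothetical stand-in for the fully connected stack of an N-Beats block."""
    layers = []
    for i in range(n_layers):
        layers.append(nn.Linear(in_features if i == 0 else hidden, hidden))
        layers.append(nn.SELU())  # SELU in place of ReLU
        if dropout > 0:
            layers.append(nn.AlphaDropout(dropout))  # AlphaDropout in place of Dropout
    stack = nn.Sequential(*layers)
    stack.apply(lecun_normal_)  # overwrite the default init on every nn.Linear
    return stack


torch.manual_seed(0)
stack = make_selu_stack(in_features=256, hidden=256, n_layers=4, dropout=0.1)
stack.eval()  # switch dropout off for the check below

x = torch.randn(1024, 256)  # standardized inputs
with torch.no_grad():
    y = stack(x)
print(f"output mean={y.mean().item():.3f} std={y.std().item():.3f}")
```

With standardized inputs, the printed mean and standard deviation should stay close to 0 and 1 after several layers, which is the self-normalizing behaviour SELU is designed for. `Module.apply` visits every submodule, so the same helper could be applied to a full model after construction.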