Closed: pnmartinez closed this issue 3 years ago
Hi,
I think I cracked it! :smile:
I was not familiar with weight initialization in PyTorch. Now I know it's messy, and the API will be revamped in that regard in version 1.9.
However, I think I implemented everything SELU requires, i.e. 3 changes:

- replace `ReLU()` with `SELU()`
- initialize the weights with `lecun_normal`
- replace regular dropout with `AlphaDropout`
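For anyone skimming, here is a minimal sketch of what those three changes can look like in a plain PyTorch block. The class name `SNNBlock`, the layer sizes, and the dropout rate are mine for illustration, not taken from the gist:

```python
import math
from torch import nn


class SNNBlock(nn.Module):
    """Fully connected block with the three SELU-related changes."""

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.activation = nn.SELU()               # change 1: SELU instead of ReLU
        self.dropout = nn.AlphaDropout(dropout)   # change 2: AlphaDropout instead of Dropout
        self._init_weights()                      # change 3: lecun_normal initialization

    def _init_weights(self):
        # lecun_normal draws weights from N(0, 1 / fan_in); PyTorch has no
        # built-in helper for it, so compute the std from fan_in manually.
        fan_in = self.linear.in_features
        nn.init.normal_(self.linear.weight, mean=0.0, std=math.sqrt(1.0 / fan_in))
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        return self.dropout(self.activation(self.linear(x)))
```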
If anyone wants to double-check my implementation, or finds it useful, here it is: https://gist.github.com/pnmartinez/fef1f488497fa85a2cc1626af2a5b4bd
Hi @jdb78 ,
I am trying to customize your N-Beats implementation to use `SELU()` instead of `ReLU()` as the activation function.

It's easy to swap in PyTorch's implementation, but `SELU()` "demands `lecun_normal` for weight initialization, and if dropout wants to be applied, one should use `AlphaDropout`" (source).

I searched your codebase, but I can only see weight inits on the TFT, not even in the `BaseModel` that `NBeats` inherits from.

Can you give any pointer on how N-Beats initializes its weights? Maybe the specific torch class?
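For context on the question: when a PyTorch model defines no explicit initialization, each layer falls back to its own `reset_parameters()` (for `nn.Linear` that is Kaiming-uniform by default), which is why nothing shows up in the model code itself. A minimal sketch of retrofitting lecun-normal init onto an already-built model with `Module.apply`; the helper name `lecun_normal_` and the toy model are mine:

```python
import math
from torch import nn


def lecun_normal_(module: nn.Module) -> None:
    """Re-initialize every Linear layer with lecun_normal weights."""
    if isinstance(module, nn.Linear):
        fan_in = module.in_features
        nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(1.0 / fan_in))
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# hypothetical usage on an existing network:
model = nn.Sequential(nn.Linear(32, 64), nn.SELU(), nn.Linear(64, 1))
model.apply(lecun_normal_)  # .apply() walks all submodules recursively
```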
Cheers and keep up the good work!