lucidrains / product-key-memory

Standalone Product Key Memory module in Pytorch - for augmenting Transformer models
MIT License

Training insights #1

Closed: david-macleod closed this issue 3 years ago

david-macleod commented 4 years ago

Another great repo, thanks for sharing! I was wondering if you had any advice for Transformers augmented with these product key memory layers. I currently have an 8/16-layer vanilla Transformer stack for a language modelling task, and I am trying to replace the feed-forward block in the middle layer of the stack with a PKM layer, as suggested in the paper. However, this makes training quite unstable: validation losses have high variance and are generally higher.
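For context, here is a minimal sketch of this setup: a single-head product-key memory and a pre-norm causal Transformer stack in which only the middle block's feed-forward sublayer is replaced by the memory. This is not the repo's `PKM` class. The names `ProductKeyMemory`, `Block` and `build_stack`, the sizes, the single query head, and taking "middle" to mean index `depth // 2` are all assumptions for illustration.

```python
import torch
import torch.nn as nn


class ProductKeyMemory(nn.Module):
    """Single-head product-key memory (illustrative sketch, not the repo's PKM)."""

    def __init__(self, dim, num_keys=32, topk=8, dim_key=16):
        super().__init__()
        self.num_keys, self.topk = num_keys, topk
        self.to_query = nn.Linear(dim, 2 * dim_key, bias=False)
        # Two sub-key sets; the implicit full key set is their Cartesian product.
        self.sub_keys = nn.Parameter(torch.randn(2, num_keys, dim_key) * dim_key ** -0.5)
        # One value slot per product key: num_keys ** 2 slots in total.
        self.values = nn.EmbeddingBag(num_keys ** 2, dim, mode="sum")

    def forward(self, x):
        b, t, d = x.shape
        n = b * t
        q = self.to_query(x).reshape(n, 2, -1)  # split each query into two halves
        # Score each half against its own sub-key set: (n, 2, num_keys).
        scores = torch.einsum("nhd,hkd->nhk", q, self.sub_keys)
        s, i = scores.topk(self.topk, dim=-1)  # best sub-keys per half
        # topk * topk candidates; a product key's score is the sum of its halves' scores.
        cand_s = (s[:, 0, :, None] + s[:, 1, None, :]).reshape(n, -1)
        cand_i = (i[:, 0, :, None] * self.num_keys + i[:, 1, None, :]).reshape(n, -1)
        top_s, pos = cand_s.topk(self.topk, dim=-1)
        idx = cand_i.gather(-1, pos)  # flat indices into the value table
        weights = top_s.softmax(dim=-1)
        # Weighted sum of the selected value slots: (n, dim).
        out = self.values(idx, per_sample_weights=weights)
        return out.reshape(b, t, d)


class Block(nn.Module):
    """Pre-norm causal Transformer block with a swappable feed-forward sublayer."""

    def __init__(self, dim, heads, ff):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = ff

    def forward(self, x):
        t = x.shape[1]
        # True entries are masked out, so each position only attends to the past.
        causal = torch.triu(torch.ones(t, t, dtype=torch.bool, device=x.device), diagonal=1)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, attn_mask=causal, need_weights=False)[0]
        return x + self.ff(self.norm2(x))


def build_stack(dim=64, depth=8, heads=4):
    blocks = []
    for layer in range(depth):
        if layer == depth // 2:
            ff = ProductKeyMemory(dim)  # the memory replaces the FF sublayer here
        else:
            ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        blocks.append(Block(dim, heads, ff))
    return nn.Sequential(*blocks)


stack = build_stack()
out = stack(torch.randn(2, 10, 64))  # (batch, seq, dim) in, same shape out
```

The exact top-k over the full key set only needs the top-k of each half, because any product key in the overall top-k must combine a top-k sub-key from both halves. This is what keeps the lookup cheap despite `num_keys ** 2` slots.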

I have tried increasing the learning rate for these parameters (using the default you suggest), which does improve performance, but training remains unstable. I am also tracking the memory usage and KL-divergence metrics. The KL in particular is much higher than in the paper and decreases only very slowly as training progresses, which suggests the model is not fully exploiting the memory.
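A sketch of the two knobs mentioned above, written as generic PyTorch rather than with the repo's helpers. `memory_usage_and_kl` follows my reading of the PKM paper's metrics: usage is the fraction of value slots that receive any weight, and KL is the divergence between the normalised accumulated weights and the uniform distribution. `split_param_groups` assumes the memory values are the only `nn.EmbeddingBag` parameters in the model. Both function names are hypothetical, and any learning rates you pass are placeholders, not the repo's defaults.

```python
import math

import torch
import torch.nn as nn


def memory_usage_and_kl(indices, weights, num_slots):
    """indices, weights: (n, topk) slot indices and softmax weights gathered over an eval set."""
    # Total weight each memory slot received.
    z = torch.zeros(num_slots, dtype=weights.dtype, device=weights.device)
    z.index_add_(0, indices.reshape(-1), weights.reshape(-1))
    usage = (z > 0).float().mean().item()  # fraction of slots accessed at least once
    p = z / z.sum()
    p = p[p > 0]  # 0 * log 0 is taken as 0
    # KL(p || uniform) = log(num_slots) + sum_i p_i * log(p_i)
    kl = math.log(num_slots) + (p * p.log()).sum().item()
    return usage, kl


def split_param_groups(model, base_lr, memory_lr):
    """Give memory values (assumed: every nn.EmbeddingBag) their own learning rate."""
    memory_ids = {
        id(p)
        for m in model.modules()
        if isinstance(m, nn.EmbeddingBag)
        for p in m.parameters()
    }
    memory = [p for p in model.parameters() if id(p) in memory_ids]
    rest = [p for p in model.parameters() if id(p) not in memory_ids]
    return [{"params": rest, "lr": base_lr}, {"params": memory, "lr": memory_lr}]
```

The groups go straight into an optimizer, e.g. `torch.optim.Adam(split_param_groups(model, base_lr=..., memory_lr=...))`. Lower KL means more uniform access. It is 0 when every slot receives the same total weight and reaches `log(num_slots)` when all weight lands on a single slot.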

Do you have much experience training models with these layers, and could you share any insights or advice about the behavior you have observed?

lucidrains commented 4 years ago

@david-macleod Hey David! Sorry for the late response. I came across a paper, https://arxiv.org/pdf/2010.03881.pdf, which proposes putting the PKM in parallel with the feed-forward (FF) block instead of substituting it for the FF, so the layer output is PKM + FF + residual. Would you be willing to give that a try and let me know if it checks out?
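The parallel arrangement can be sketched as a single sublayer in which both branches read the same normalised input `h`, so the output is `x + FF(h) + memory(h)`. Sharing one LayerNorm between the branches is an assumption here, as is the class name `ParallelFFMemory`. The `nn.Linear` passed as `memory` is only a stand-in so the snippet runs; a PKM layer would take its place.

```python
import torch
import torch.nn as nn


class ParallelFFMemory(nn.Module):
    """Pre-norm sublayer: residual + feed-forward branch + memory branch, summed."""

    def __init__(self, dim, memory, mult=4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, mult * dim), nn.GELU(), nn.Linear(mult * dim, dim))
        self.memory = memory  # any module mapping (batch, seq, dim) -> (batch, seq, dim)

    def forward(self, x):
        h = self.norm(x)
        # The memory is added alongside the FF output instead of replacing it.
        return x + self.ff(h) + self.memory(h)


dim = 8
block = ParallelFFMemory(dim, memory=nn.Linear(dim, dim))  # nn.Linear stands in for a PKM layer
y = block(torch.randn(2, 5, dim))
```

Because the FF branch is kept, the block still behaves like an ordinary Transformer layer while the memory is poorly trained, which is one plausible reason this arrangement is gentler than a straight replacement.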

gyuwankim commented 3 years ago

@lucidrains Thanks for citing our paper; I recognize this repo now. @david-macleod I hope our proposed method helps in your case. I believe that, in addition to the residual idea, initializing from pre-trained weights trained without memory could make training much more stable.
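One way to do that initialization in PyTorch, assuming the memory parameters live under a `memory.` prefix that the pre-trained baseline lacks: load the baseline state dict with `strict=False`, so every shared weight is copied and only the memory branch keeps its fresh initialization, then check that nothing else went missing. `ToyLayer` and `init_from_pretrained` are hypothetical names, and the `nn.Linear` memory is a stand-in for a PKM layer.

```python
import torch
import torch.nn as nn


class ToyLayer(nn.Module):
    """Hypothetical layer: FF branch plus an optional memory branch."""

    def __init__(self, dim, memory=None):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Linear(dim, dim)
        self.memory = memory  # None for the baseline, a memory module when augmented

    def forward(self, x):
        h = self.norm(x)
        out = x + self.ff(h)
        if self.memory is not None:
            out = out + self.memory(h)
        return out


def init_from_pretrained(model, pretrained_state, memory_prefix="memory."):
    """Copy all shared weights; only parameters under memory_prefix may stay untouched."""
    result = model.load_state_dict(pretrained_state, strict=False)
    unexpected_missing = [k for k in result.missing_keys if not k.startswith(memory_prefix)]
    if unexpected_missing or result.unexpected_keys:
        raise ValueError(
            f"state dict mismatch: missing={unexpected_missing}, unexpected={result.unexpected_keys}"
        )
    return result.missing_keys  # the freshly initialized memory parameters


baseline = ToyLayer(8)  # pretend this was trained without memory
augmented = ToyLayer(8, memory=nn.Linear(8, 8))  # stand-in for a PKM layer
missing = init_from_pretrained(augmented, baseline.state_dict())
```

Raising on any other missing or unexpected key guards against the silent failure mode of `strict=False`, where a renamed module would otherwise just keep its random weights.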