It works! It does slightly better than the default model (85.93% vs. 85.6%) and gets rid of those awkward norms.
I also tried swapping out other Linear layers, in SPT and in the output projection, but that breaks the model. So these changes only seem to make sense inside the transformer blocks.
From Apple paper: "Stabilizing Transformer Training by Preventing Attention Entropy Collapse" https://github.com/apple/ml-sigma-reparam
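For context, the paper's core idea (σ-reparam) is to reparameterize each Linear layer's weight as W_hat = (γ / σ(W)) · W, where σ(W) is the largest singular value (estimated cheaply with power iteration) and γ is a learnable scalar, so the spectral norm of the effective weight is controlled directly. Here is a minimal sketch of that idea; the class name and details are mine, not the API of the Apple repo, so treat it as an illustration rather than their implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SigmaReparamLinear(nn.Module):
    """Sketch of a sigma-reparam Linear layer: the effective weight is
    (gamma / sigma(W)) * W, where sigma(W) is the top singular value
    estimated with one power-iteration step per training forward pass."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weight)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.gamma = nn.Parameter(torch.ones(1))
        # Power-iteration vectors for the spectral-norm estimate (not trained).
        self.register_buffer("u", F.normalize(torch.randn(out_features), dim=0))
        self.register_buffer("v", F.normalize(torch.randn(in_features), dim=0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # One power-iteration step, detached from the autograd graph.
            with torch.no_grad():
                self.v.copy_(F.normalize(self.weight.t() @ self.u, dim=0))
                self.u.copy_(F.normalize(self.weight @ self.v, dim=0))
        # sigma ~= u^T W v approximates the largest singular value of W.
        sigma = torch.dot(self.u, self.weight @ self.v)
        w_hat = (self.gamma / sigma) * self.weight
        return F.linear(x, w_hat, self.bias)
```

The effective weight then has spectral norm roughly equal to γ, which is what keeps attention logits from blowing up and lets you drop the extra norm layers.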
Here's my experiment:
https://github.com/catid/cifar10deepspeed/commit/72a0b78d3ff4a1317d53b4a1d7710540a357eb17
On Twitter the authors said you can apply it to other layers like this: https://github.com/catid/cifar10deepspeed/pull/1/files. However, it doesn't work in my script for whatever reason, and it seems like a bit of a hack.