HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU-Training
BSD 2-Clause "Simplified" License

MuP Normalization #60

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

Currently, we apply MuP as a per-layer learning rate scale, which trains faster and uses less memory than the method the MuP authors recommend: initializing to a larger value and multiplying the outputs of every layer by a constant scalar. However, we can further improve the speed of our model by fusing MuParametrization's scales with those of normalization. Instead of training with mean=0 and std=1, we'd train with mean=0 and l2norm=1. Our new tensors would then have a standard deviation of 1/sqrt(numel), which pushes their values ever closer to 0, a region where floats have increasingly fine absolute resolution. Computing this is very cheap as well: we simply compute sqrt(sum(x^2)) instead of sqrt(mean(x^2)), or, in other words, remove a scalar multiplication.
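To make the difference concrete, here is a minimal JAX sketch of the two scalings. It is not taken from the Olmax code: the function names `rms_normalize` and `l2_normalize`, the `eps` guard, the feature size, and the choice to normalize over the last axis (so that "numel" means the number of features) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def rms_normalize(x, eps=1e-6):
    # Conventional scaling: divide by sqrt(mean(x^2)), so std is ~1 for zero-mean input.
    scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True))
    return x / jnp.maximum(scale, eps)


def l2_normalize(x, eps=1e-6):
    # Proposed scaling: divide by sqrt(sum(x^2)), so the L2 norm is 1.
    scale = jnp.sqrt(jnp.sum(jnp.square(x), axis=-1, keepdims=True))
    return x / jnp.maximum(scale, eps)


features = 256  # hypothetical feature size, plays the role of "numel"
x = jax.random.normal(jax.random.PRNGKey(0), (4, features))
x = x - x.mean(axis=-1, keepdims=True)  # enforce mean=0 along the feature axis

rms_out = rms_normalize(x)
l2_out = l2_normalize(x)

# The two outputs differ only by the constant factor sqrt(features).
print(jnp.std(rms_out, axis=-1))                      # per-row std of the RMS variant
print(jnp.std(l2_out, axis=-1) * jnp.sqrt(features))  # per-row std of the L2 variant, rescaled
```

Since the two outputs differ only by the constant factor sqrt(features), that factor is what could in principle be folded into the MuP scale rather than applied as a separate multiplication.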

ClashLuke commented 2 years ago

solved by #63