HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

LpNorm + ScaleNorm #77

Closed · ClashLuke closed this 1 year ago

ClashLuke commented 1 year ago

ScaleNorm: [image]

LayerNorm: [image]
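For reference, a minimal JAX sketch of the two normalizations (an approximation for illustration, not necessarily this repository's exact implementation): LayerNorm centers and rescales each feature with a learned per-feature scale, while ScaleNorm only rescales the whole vector by a single learned scalar over its L2 norm.

```python
import jax.numpy as jnp

def layer_norm(x, scale, eps=1e-6):
    # Subtract the mean, divide by the standard deviation, apply a learned per-feature scale.
    x = x - x.mean(-1, keepdims=True)
    return scale * x / jnp.sqrt(jnp.square(x).mean(-1, keepdims=True) + eps)

def scale_norm(x, g, eps=1e-6):
    # Rescale the whole feature vector by a single learned scalar g over its L2 norm.
    return g * x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + eps)
```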

Otherwise, both models have identical settings, so purely by swapping LayerNorm for ScaleNorm (both L2), the model becomes 25% faster while achieving the same or better convergence: [image]

The L1 norm has the same speed but a worse loss early in training: [images]

Additionally, this PR makes the gradient checks harder to pass: every model now receives not only random parameters but also a random output gradient. This ensures we use the output gradient correctly, since our custom_grad functions could otherwise ignore it and still pass.
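As a rough illustration of that check (a hypothetical helper, not the repository's actual test code), one can compare a custom gradient against a reference implementation under a random cotangent; with an all-ones (or implicit) cotangent, a VJP that ignores the incoming gradient could still slip through.

```python
import jax
import jax.numpy as jnp

def check_grad(fn, reference_fn, shape, seed=0, atol=1e-4):
    # Hypothetical helper: draw a random input *and* a random output gradient
    # (cotangent), so a custom gradient that ignores the cotangent fails the check.
    # Assumes fn maps an array to an array of the same shape (true for norm layers).
    k0, k1 = jax.random.split(jax.random.PRNGKey(seed))
    x = jax.random.normal(k0, shape)
    dy = jax.random.normal(k1, shape)  # random cotangent instead of all-ones

    _, vjp = jax.vjp(fn, x)
    _, ref_vjp = jax.vjp(reference_fn, x)
    (dx,), (dx_ref,) = vjp(dy), ref_vjp(dy)
    assert jnp.allclose(dx, dx_ref, atol=atol), "custom gradient disagrees with reference"
```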

ClashLuke commented 1 year ago

All tests pass: [image]

LayerNorm has a 3% higher loss than L2-ScaleNorm: [image]

L1-ScaleNorm has an 8% higher loss than L2-ScaleNorm: [image]

Regardless of that performance gap, I added an option to change the normalization's power and to optionally centralize (mean-subtract) the inputs.
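A minimal sketch of what such an option could look like (the parameter names `power` and `centralize` are assumptions for illustration, not necessarily this repository's API):

```python
import jax.numpy as jnp

def norm(x, scale, power=2, centralize=False, eps=1e-6):
    # `power` selects the Lp norm used for rescaling (2 reproduces the L2-ScaleNorm
    # above, 1 gives the L1 variant); `centralize` subtracts the mean first,
    # as LayerNorm does.
    if centralize:
        x = x - x.mean(-1, keepdims=True)
    denom = jnp.abs(x) ** power
    denom = denom.mean(-1, keepdims=True) ** (1.0 / power)
    return scale * x / (denom + eps)
```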