HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training

Shampoo Refactor #36

Closed · ClashLuke closed this pull request 2 years ago

ClashLuke commented 2 years ago

The most significant changes are that Shampoo now uses `1 - beta` weighting like the other optimisers and that I shrank the core algorithm to 15 LOC. Additionally, I removed `ctx.parameter_dims`, which stored the dimension names of each allocated buffer, as it is no longer needed.
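For context, here is a minimal sketch of what a compact Shampoo core with `1 - beta` statistics weighting can look like. This is an illustrative reconstruction, not the PR's actual 15 LOC; the function names and the eigendecomposition-based inverse root are assumptions.

```python
import jax.numpy as jnp

def inv_pth_root(stat, p, eps=1e-6):
    # Inverse p-th root of a symmetric PSD statistics matrix via eigendecomposition.
    eigvals, eigvecs = jnp.linalg.eigh(stat)
    eigvals = jnp.maximum(eigvals, eps)  # clamp tiny eigenvalues for numerical stability
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T

def shampoo_update(grad, left, right, beta):
    # Exponential moving averages of the Kronecker factors, weighted by 1 - beta
    # to match the convention used by the other optimisers.
    left = beta * left + (1 - beta) * grad @ grad.T
    right = beta * right + (1 - beta) * grad.T @ grad
    # Precondition a 2D parameter's gradient: L^(-1/4) @ G @ R^(-1/4).
    update = inv_pth_root(left, 4) @ grad @ inv_pth_root(right, 4)
    return update, left, right
```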

ClashLuke commented 2 years ago

For improved stability, I also removed grafting onto SGD. #35 should explore grafting onto other, more stable optimisers such as SM3 and RMSProp.
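For readers unfamiliar with grafting: the grafted optimiser contributes only the per-layer step size, while Shampoo supplies the update direction. A hypothetical sketch of that norm-based recipe follows; the helper name and signature are assumptions, not this repo's API.

```python
import jax.numpy as jnp

def graft(shampoo_update, graft_update, eps=1e-12):
    # Keep Shampoo's direction, but rescale it to the norm of the update
    # produced by the grafted optimiser (SGD here, SM3/RMSProp in #35).
    scale = jnp.linalg.norm(graft_update) / (jnp.linalg.norm(shampoo_update) + eps)
    return shampoo_update * scale
```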

The model is stable and outperforms the baseline by a clear margin. The blue run shows our convergence without merging this PR. [Figure: loss curves comparing the refactored run against the blue baseline.]