HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Shampoo Optimizer #15

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

Second-order optimisers such as K-FAC, L-BFGS and AdaHessian promise significantly improved convergence rates, but at horrific memory cost. Scalable Shampoo promises a low memory footprint and vectorisable computation while retaining the convergence advantage of other second-order optimisers. Adding it to our code could reduce training time by 10%, or possibly by up to an order of magnitude.

This issue covers implementing Shampoo (the reference might help), running a hyperparameter sweep to find its best configuration, and comparing the best resulting runtime with our previous best.
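For reference, the core (non-scalable) Shampoo update for a 2D parameter is small enough to sketch directly in JAX. This is only the textbook single-matrix version with an eigendecomposition-based inverse fourth root, not the blocked/distributed variant the scalable paper describes; the function and argument names below are illustrative, not taken from the reference code:

```python
import jax.numpy as jnp


def _inv_pth_root(mat: jnp.ndarray, p: int, eps: float = 1e-6) -> jnp.ndarray:
    # mat^(-1/p) for a symmetric PSD matrix via eigendecomposition.
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    eigvals = jnp.maximum(eigvals, eps)
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T


def init_stats(param: jnp.ndarray, eps: float = 1e-4):
    # Left/right Kronecker statistics start as scaled identities.
    rows, cols = param.shape
    return eps * jnp.eye(rows), eps * jnp.eye(cols)


def shampoo_step(param, grad, left_stat, right_stat, lr=1e-3):
    # Accumulate second-moment statistics of the gradient.
    left_stat = left_stat + grad @ grad.T
    right_stat = right_stat + grad.T @ grad
    # Precondition: L^(-1/4) @ G @ R^(-1/4), then take a plain SGD-style step.
    preconditioned = _inv_pth_root(left_stat, 4) @ grad @ _inv_pth_root(right_stat, 4)
    return param - lr * preconditioned, left_stat, right_stat
```

The scalable variant adds (at least) blocking of large dimensions, infrequent preconditioner recomputation and grafting onto a first-order method, which is presumably where most of the integration work will be.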

ClashLuke commented 2 years ago

I'm working on this now.

The original JAX-based implementation (without Optax) seems easiest to integrate. If it works well, we can consider adding Optax support and integrating their Optax version. Optax-Shampoo seems to support quantisation and many other features, which would make downstream tasks like #14 easier.
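For what it's worth, wiring the Optax version into a training setup would presumably look roughly like the sketch below. This assumes the reference `distributed_shampoo.py` is vendored next to our code as `distributed_shampoo` and that its `distributed_shampoo(...)` factory returns an `optax.GradientTransformation`; the exact keyword arguments (`block_size`, `beta1`, `beta2`) should be double-checked against the reference file:

```python
import optax
# Assumed import path: the reference distributed_shampoo.py vendored into our repo.
from distributed_shampoo import distributed_shampoo


def make_optimizer(lr: float = 1e-3) -> optax.GradientTransformation:
    # Chain global-norm clipping with the Shampoo transformation, Optax-style.
    # Argument names below follow the reference implementation and need verifying.
    return optax.chain(
        optax.clip_by_global_norm(1.0),
        distributed_shampoo(learning_rate=lr, block_size=128, beta1=0.9, beta2=0.999),
    )
```

If this turns out to be the easier path, it would also give us the usual chained-transformation pattern (clipping, weight decay, schedules) essentially for free.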