HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Shampoo Optimizer #34

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

Closes #15

ClashLuke commented 2 years ago

I'm currently testing the performance of this modification in this WandB run (which uses this run as a baseline). The step time doesn't seem much worse. However, I could not activate FP64 matrix inversion without changing all defaults to FP64, as JAX's JAX_DEFAULT_DTYPE_BITS environment variable has no effect. FP64 has to be fixed before merging this PR.
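For reference, a minimal sketch of what forcing only the matrix inversion into FP64 could look like, assuming the global `jax_enable_x64` switch is flipped and only the preconditioner statistic is upcast around the inverse root. The function name and the eigendecomposition-based inverse root are illustrative, not the code in this PR:

```python
import jax
import jax.numpy as jnp

# Global switch; JAX exposes no per-op default-dtype knob, so this flips all defaults to 64-bit.
jax.config.update("jax_enable_x64", True)

def inverse_pth_root_fp64(stat: jnp.ndarray, p: int, eps: float = 1e-6) -> jnp.ndarray:
    """Compute stat^(-1/p) in float64 and return it in the input dtype."""
    orig_dtype = stat.dtype
    stat64 = stat.astype(jnp.float64)
    # Regularize before the eigendecomposition for numerical stability.
    stat64 = stat64 + eps * jnp.eye(stat64.shape[0], dtype=jnp.float64)
    eigvals, eigvecs = jnp.linalg.eigh(stat64)
    # A^(-1/p) = V diag(lambda^(-1/p)) V^T, clamping eigenvalues away from zero.
    inv_root = (eigvecs * jnp.clip(eigvals, eps, None) ** (-1.0 / p)) @ eigvecs.T
    return inv_root.astype(orig_dtype)
```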

ClashLuke commented 2 years ago
RuntimeError: UNIMPLEMENTED: While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %reduce-scatter.4431 = s64[1]{0} reduce-scatter(s64[8]{0} %iota.4430), replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, to_apply=%region_16.288, metadata={op_name="pmap(jitless_step)/jit(main)/body/reduce_scatter[axis_name=model_parallel scatter_dimension=0 axis_index_groups=None axis_size=8 tiled=False]" source_file="/home/ubuntu/HomebrewNLP-Jax/src/model.py" source_line=331}

This needs outside input.
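One possible direction (an assumption, not a confirmed fix): the failing HLO is a 64-bit integer reduce-scatter built from an iota, so keeping the index arithmetic in int32, or casting it before the collective, should keep XLA away from the unimplemented X64 rewrite. A minimal sketch of that idea, with hypothetical names:

```python
import jax
import jax.numpy as jnp

def scatter_device_indices(axis_name: str, axis_size: int) -> jnp.ndarray:
    # Build the iota explicitly as int32 so the pmap-level reduce-scatter never
    # carries an s64 operand (with x64 enabled, a plain arange defaults to int64).
    idx = jnp.arange(axis_size, dtype=jnp.int32)
    return jax.lax.psum_scatter(idx, axis_name, scatter_dimension=0, tiled=False)

# Illustrative usage: inside jax.pmap(..., axis_name="model_parallel") over 8 devices,
# scatter_device_indices("model_parallel", 8) yields one index per device.
```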

ClashLuke commented 2 years ago

It's backwards compatible and needs to be merged, as the last couple of commits fixed the dataset. This merge is mainly necessary so our sweeps work well and is unrelated to whether Shampoo works well. A separate PR will explore adding RMSProp, and another will attempt to remove it.

ClashLuke commented 2 years ago

WandB runs of Shampoo vs SM3


Shampoo is a bit slower at 38.8 s/step compared to SM3's 37.4 s/step (step-time screenshots attached).


It also seems like Shampoo converges a little worse (screenshot attached).

That's why disabling it by default seems sensible: the code stays in the repository, and the PR can be merged faster.
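For context, a sketch of the kind of toggle "disabled by default" could mean; `OptimizerConfig`, `use_shampoo`, and the optax-based SM3 fallback are illustrative assumptions, not this repository's actual configuration:

```python
import dataclasses
import optax

@dataclasses.dataclass
class OptimizerConfig:
    use_shampoo: bool = False   # Shampoo code stays merged but is off by default
    learning_rate: float = 1e-3

def build_optimizer(cfg: OptimizerConfig) -> optax.GradientTransformation:
    if cfg.use_shampoo:
        # Placeholder: the repository's Shampoo transformation would plug in here.
        raise NotImplementedError("Shampoo path is merged but not enabled by default")
    # SM3 baseline, available directly in optax.
    return optax.sm3(cfg.learning_rate)
```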