**Closed** · ClashLuke closed this 2 years ago
I'm currently testing the performance of this modification in this wandb run (which uses this as a baseline). The step time doesn't seem much worse. However, I could not activate FP64 matrix inversion without changing all defaults to FP64, since JAX's `JAX_DEFAULT_DTYPE_BITS` doesn't do anything. FP64 has to be fixed before merging this PR.
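For reference, a minimal sketch of what "changing all defaults to FP64" looks like, assuming the standard `jax_enable_x64` config flag (the usage below is illustrative, not taken from this repo):

```python
import jax

# Enabling X64 globally was the only switch that worked for me;
# setting the JAX_DEFAULT_DTYPE_BITS environment variable had no effect.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.eye(4)           # float64 by default once X64 is enabled
inv = jnp.linalg.inv(x)  # FP64 matrix inversion
print(inv.dtype)         # float64
```

The downside is that this flips *every* default dtype to 64-bit, which is exactly what triggers the backend error below.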
```
RuntimeError: UNIMPLEMENTED: While rewriting computation to not contain X64 element types, XLA encountered an HLO for which this rewriting is not implemented: %reduce-scatter.4431 = s64[1]{0} reduce-scatter(s64[8]{0} %iota.4430), replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, to_apply=%region_16.288, metadata={op_name="pmap(jitless_step)/jit(main)/body/reduce_scatter[axis_name=model_parallel scatter_dimension=0 axis_index_groups=None axis_size=8 tiled=False]" source_file="/home/ubuntu/HomebrewNLP-Jax/src/model.py" source_line=331}
```
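For anyone trying to reproduce this, here's a hypothetical minimal sketch of the failing pattern, an int64 reduce-scatter across the model-parallel axis. The function name, axis name, and shapes are taken from the error metadata; the body is my guess, not the actual code at `src/model.py:331`:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # needed for a true s64 iota

def jitless_step(_):
    # s64[8] iota reduce-scattered across 8 devices, mirroring the
    # reduce-scatter HLO in the traceback.
    idx = jnp.arange(jax.device_count(), dtype=jnp.int64)
    return jax.lax.psum_scatter(idx, axis_name="model_parallel")

# On backends without native 64-bit support, XLA tries to rewrite the
# X64 types away and hits the UNIMPLEMENTED reduce-scatter above.
out = jax.pmap(jitless_step, axis_name="model_parallel")(
    jnp.zeros(jax.device_count())
)
```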
This needs outside input.
It's backwards compatible and needs to be merged, as the last couple of commits fixed the dataset. This merge is necessary mainly for our sweeps to work well, and is unrelated to whether Shampoo works well. A separate PR will explore adding RMSProp, and another will attempt to remove it.
Closes #15