The indexing performed using the TRIL indices was not originally supported by JAX (or maybe I was doing something wrong).
Whatever the case, it is now possible to significantly clean up the GM implementation. Training results and timings are exactly the same.
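
For reference, a minimal sketch of the kind of TRIL indexing meant here, assuming the indices are used to scatter a flat parameter vector into a lower-triangular matrix. The helper name `fill_lower_triangular` and the example values are hypothetical and not taken from the actual GM code:

```python
import jax
import jax.numpy as jnp


def fill_lower_triangular(flat, dim):
    """Scatter a flat vector into the lower triangle of a (dim, dim) matrix.

    Hypothetical helper for illustration; `dim` must be a static Python int
    so that the index arrays have a fixed shape under jit.
    """
    rows, cols = jnp.tril_indices(dim)  # row-major indices of the lower triangle
    return jnp.zeros((dim, dim), dtype=flat.dtype).at[rows, cols].set(flat)


dim = 3
flat = jnp.arange(1.0, 7.0)  # 6 = dim * (dim + 1) / 2 entries

# Mark `dim` as static so the function can be jit-compiled.
L = jax.jit(fill_lower_triangular, static_argnums=1)(flat, dim)

# Gathering with the same indices recovers the flat vector.
rows, cols = jnp.tril_indices(dim)
recovered = L[rows, cols]
```

Both the scatter (`.at[rows, cols].set(...)`) and the gather (`L[rows, cols]`) go through the same index pair, so no manual loops or masks should be needed.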