AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

RA update works for all axes orders #859

Closed patemotter closed 3 weeks ago

patemotter commented 3 weeks ago

The previous version of the update expected a specific axis ordering, which is not always guaranteed. This change updates the cache correctly regardless of the axis ordering.
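As a rough illustration of the idea (not the code in this PR), an order-independent update can look up the position of the sequence axis by name instead of assuming it sits at a fixed index. The function name `update_cache_any_order`, the `axis_names` tuple, and the `"sequence"` label below are all hypothetical and chosen only for this sketch. The actual maxtext names and cache layout may differ.

```python
import jax
import jax.numpy as jnp


def update_cache_any_order(cache, new_values, position, axis_names,
                           seq_axis_name="sequence"):
    """Write `new_values` into `cache` at `position` along the sequence axis.

    `axis_names` (hypothetical) lists the logical name of every cache axis,
    so the sequence axis is found by name rather than by a hard-coded index.
    `new_values` must have the same rank as `cache`.
    """
    seq_axis = axis_names.index(seq_axis_name)
    return jax.lax.dynamic_update_slice_in_dim(
        cache, new_values, position, axis=seq_axis
    )


# Same logical update applied to two different axis orderings.
order_a = ("batch", "sequence", "heads", "kv")
order_b = ("sequence", "batch", "heads", "kv")

out_a = update_cache_any_order(
    jnp.zeros((2, 8, 4, 16)), jnp.ones((2, 1, 4, 16)), 3, order_a
)
out_b = update_cache_any_order(
    jnp.zeros((8, 2, 4, 16)), jnp.ones((1, 2, 4, 16)), 3, order_b
)
```

In both cases only the slice at sequence position 3 is written, so transposing `out_b` back to the ordering of `out_a` gives the same array.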