google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0
1.44k stars 263 forks

Add compute_axis_order #703

Closed: morgandu closed this 2 months ago