AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0
1.47k stars, 275 forks

Optimize overhead right before the first train_step #842

Closed: ZhiyuLi-goog closed this 1 month ago