google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

[mlperf/4.1] enable shard_in_read for large scaling training #839

Closed. ZhiyuLi-goog closed this 3 weeks ago

ZhiyuLi-goog commented 3 weeks ago

@anfals could you take a look? We need shard_in_read for the v5p-12288 run.