google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0
1.45k stars 271 forks

Mlperf/4.1 grain #840

Open aireenmei opened 3 weeks ago

aireenmei commented 3 weeks ago

The padding batch issue for eval is fixed now. To use this feature, set eval_steps to just enough steps to consume all of the eval data (the data iterator generates empty examples until eval_steps is reached). For the MLPerf eval data, set eval_steps = math.ceil(5700/global_batch_size) and check that the output total_weights==11590004. Add more eval_steps if needed for total_weights to reach that value.
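As a rough illustration of the arithmetic above, here is a minimal sketch in plain Python. The helper names `min_eval_steps` and `eval_covers_dataset` are hypothetical and not part of MaxText. The constants 5700 and 11590004 are taken from this comment. How `total_weights` is read from the eval output is left to the caller.

```python
import math

# Values quoted in the comment above for the MLPerf eval data.
MLPERF_EVAL_EXAMPLES = 5700
EXPECTED_TOTAL_WEIGHTS = 11590004


def min_eval_steps(global_batch_size: int,
                   num_examples: int = MLPERF_EVAL_EXAMPLES) -> int:
    """Smallest eval_steps whose batches cover num_examples (hypothetical helper)."""
    if global_batch_size <= 0:
        raise ValueError("global_batch_size must be positive")
    return math.ceil(num_examples / global_batch_size)


def eval_covers_dataset(total_weights: int,
                        expected: int = EXPECTED_TOTAL_WEIGHTS) -> bool:
    """True when the reported total_weights matches the expected value."""
    return total_weights == expected


if __name__ == "__main__":
    # Starting value of eval_steps for a few example global batch sizes.
    for batch_size in (256, 512, 1024):
        print(batch_size, min_eval_steps(batch_size))
```

If `eval_covers_dataset` returns False after a run, increase eval_steps above the computed minimum and check total_weights again, as described above.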