AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!

new features with distributed training framework #821

Closed bernardhan33 closed 1 month ago

bernardhan33 commented 1 month ago

Features

  1. Support for modeling per_step_interval.
  2. Support for modeling max_steps.
  3. Support for listing the dataset directory, taking the bucket name from the run flags instead of hard-coding it.

(1) and (2) come from the design doc. (3) fixes a bug identified in a conversation with the HNS team: the hard-coded values prevent the benchmark from being easily run against different bucket names. A sketch of how these features fit together is included below.
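For context, here is a minimal sketch of a benchmark loop combining (1), (2), and (3). The flag names (`--bucket_name`, `--dataset_prefix`, `--per_step_interval`, `--max_steps`) and the step stub are illustrative assumptions, not the code in this PR.

```python
# Hedged sketch only: flag names and stubs below are assumptions, not this PR's code.
import argparse
import time

from google.cloud import storage  # requires GCP credentials when actually run


def list_dataset_files(bucket_name: str, prefix: str) -> list:
    """Feature (3): list dataset objects under a bucket passed via run flags."""
    client = storage.Client()
    return [blob.name for blob in client.list_blobs(bucket_name, prefix=prefix)]


def run_one_step(files, step) -> None:
    """Placeholder for the real data-loading + training work."""
    del files, step


def train(args) -> None:
    files = list_dataset_files(args.bucket_name, args.dataset_prefix)
    step = 0
    while step < args.max_steps:  # feature (2): stop gracefully at the global step limit
        start = time.monotonic()
        run_one_step(files, step)
        # Feature (1): pad each step so it takes at least per_step_interval seconds.
        elapsed = time.monotonic() - start
        if elapsed < args.per_step_interval:
            time.sleep(args.per_step_interval - elapsed)
        step += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bucket_name", required=True)
    parser.add_argument("--dataset_prefix", default="")
    parser.add_argument("--per_step_interval", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=1000)
    train(parser.parse_args())
```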

Internal CL to update the README and the YAML file: cl/662264972.

Tested by

  1. Setting per_step_interval to 1 second and confirming that the per-step time is roughly 1 second, except for steps whose data loading takes longer.
  2. Setting max_steps and confirming that training stops gracefully once the global step limit is reached and that the metrics are recorded correctly.
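A minimal sketch of the kind of check described in (1): pace a few steps and confirm each takes roughly per_step_interval unless the simulated data loading runs longer. This is an illustrative assumption of the manual verification, not the PR's test code.

```python
# Hedged sketch of the manual per_step_interval check; not the PR's test code.
import time

PER_STEP_INTERVAL = 1.0  # seconds, matching the manual test above

durations = []
for step in range(5):
    start = time.monotonic()
    time.sleep(0.1)  # simulated data loading; in the real benchmark this may exceed the interval
    elapsed = time.monotonic() - start
    if elapsed < PER_STEP_INTERVAL:
        time.sleep(PER_STEP_INTERVAL - elapsed)
    durations.append(time.monotonic() - start)

# Each paced step should be roughly 1 second unless data loading ran longer.
assert all(d >= PER_STEP_INTERVAL for d in durations), durations
print(durations)
```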

Next steps

  1. Once this PR and the CL are merged, I'll build the image and upload it to gcr.io/gcs-tess/distributed_pytorch_training_benchmark.
  2. Once we start moving this to the new, simpler framework, I'll add unit tests covering the various features.