google / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Fix Mesh setup for multiprocess CPUs. #723

Closed: RoshaniN closed this 4 weeks ago

RoshaniN commented 1 month ago
  1. The condition for JAX distributed initialization needed to be reordered so that the GPU and CPU initializations can be triggered. The following condition evaluates to true under the base.yaml defaults unless it is overridden from the command line, so checking it first prevents the GPU and CPU branches from running:

    if (
      raw_keys["enable_checkpointing"]
      and raw_keys["async_checkpointing"]
      and raw_keys["compile_topology_num_slices"] == -1
      and not raw_keys["enable_single_controller"]
    ):
  2. get_num_slices(raw_keys) has new logic that computes the number of slices and the number of devices per slice. This logic is adapted for CPUs: slices have no meaning on CPUs because they do not support a hierarchical network, so num_slices is set to 1 and the existing ICI parallelism logic in max_utils is used as-is.
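The two changes above can be sketched in plain Python. This is a minimal, hypothetical sketch, not MaxText's actual code: `choose_init_path`, `num_slices_sketch`, `fill_axis`, `FakeDevice`, the returned labels, and the `hardware` key values are all assumptions for illustration. It checks the hardware type before the checkpointing condition, returns a single slice on CPUs, and then shows that with one slice every device lands on the ICI mesh axes while the DCN axes collapse to 1.

```python
# Hypothetical sketch; names, labels, and key values are not MaxText's API.
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class FakeDevice:
    """Stand-in for a jax.Device; multislice TPU devices expose slice_index."""
    platform: str
    slice_index: Optional[int] = None


def choose_init_path(raw_keys: dict) -> str:
    """Return a label for the distributed-init path (labels are hypothetical)."""
    # Hardware-specific branches are checked first. Under the defaults the
    # checkpointing condition below is true, so testing it first would
    # shadow the GPU and CPU paths.
    if raw_keys["hardware"] in ("gpu", "cpu"):
        return raw_keys["hardware"]
    if (
        raw_keys["enable_checkpointing"]
        and raw_keys["async_checkpointing"]
        and raw_keys["compile_topology_num_slices"] == -1
        and not raw_keys["enable_single_controller"]
    ):
        return "default"
    return "none"


def num_slices_sketch(raw_keys: dict, devices: Sequence[FakeDevice]) -> int:
    """Count slices; CPUs are always treated as a single slice."""
    if raw_keys["compile_topology_num_slices"] > 0:
        return raw_keys["compile_topology_num_slices"]
    if raw_keys["hardware"] == "cpu":
        # No hierarchical network on CPUs, so slices have no meaning.
        return 1
    indices = [d.slice_index for d in devices if d.slice_index is not None]
    return 1 + max(indices) if indices else 1


def fill_axis(parallelism: List[int], target: int) -> List[int]:
    """Replace at most one -1 so the product of the axes equals target."""
    vals = list(parallelism)
    if vals.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    if -1 in vals:
        known = math.prod(v for v in vals if v != -1)
        if target % known:
            raise ValueError(f"{target} is not divisible by {known}")
        vals[vals.index(-1)] = target // known
    if math.prod(vals) != target:
        raise ValueError(f"axes {vals} do not multiply to {target}")
    return vals


# Usage: 4 CPU devices with the checkpointing keys left at their assumed defaults.
cpu_keys = {
    "hardware": "cpu",
    "enable_checkpointing": True,
    "async_checkpointing": True,
    "compile_topology_num_slices": -1,
    "enable_single_controller": False,
}
cpu_devices = [FakeDevice("cpu") for _ in range(4)]
n_slices = num_slices_sketch(cpu_keys, cpu_devices)          # 1
ici = fill_axis([1, -1, 1], len(cpu_devices) // n_slices)    # [1, 4, 1]
dcn = fill_axis([-1, 1, 1], n_slices)                        # [1, 1, 1]
```

Taking the devices as an argument keeps the sketch testable without JAX; on real hardware they would come from `jax.devices()`.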

Testing on multiprocess CPUs: I tested standalone_checkpointer.py end-to-end (on 2 node pools with 2 hosts each) to verify this change.