The condition for JAX distributed initialization had to be reordered so that the GPU and CPU initialization paths can be triggered.
With the base.yaml defaults, the following condition evaluates to true unless it is overridden from the command line:
    if (
        raw_keys["enable_checkpointing"]
        and raw_keys["async_checkpointing"]
        and raw_keys["compile_topology_num_slices"] == -1
        and not raw_keys["enable_single_controller"]
    )
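To illustrate why the ordering matters, here is a minimal sketch, not the actual code in this PR. The function names `wants_async_checkpoint_init` and `pick_init_path`, the `"hardware"` key, and the returned labels are hypothetical; the default values are assumed to mirror base.yaml, as the description above implies.

```python
def wants_async_checkpoint_init(raw_keys):
    """Return True when the checkpointing condition quoted above holds."""
    return bool(
        raw_keys["enable_checkpointing"]
        and raw_keys["async_checkpointing"]
        and raw_keys["compile_topology_num_slices"] == -1
        and not raw_keys["enable_single_controller"]
    )


def pick_init_path(raw_keys):
    """Hypothetical dispatcher: hardware-specific checks come first.

    If the checkpointing check ran first, it would win under the default
    config and the GPU/CPU branches would never be reached.
    """
    hardware = raw_keys.get("hardware", "tpu")  # assumed config key
    if hardware == "gpu":
        return "gpu"
    if hardware == "cpu":
        return "cpu"
    if wants_async_checkpoint_init(raw_keys):
        return "async_checkpointing"
    return "none"


# Values assumed to match the base.yaml defaults referred to above.
defaults = {
    "enable_checkpointing": True,
    "async_checkpointing": True,
    "compile_topology_num_slices": -1,
    "enable_single_controller": False,
}
```

With these assumed defaults, `pick_init_path({**defaults, "hardware": "cpu"})` selects the CPU path even though the checkpointing condition is also true, which is the behavior the reordering is meant to allow.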
get_num_slices(raw_keys) has new logic that computes the number of slices and the number of devices in a slice. This change adapts that logic for CPUs: slices have no meaning for CPUs because they do not support a hierarchical network, so num_slices is set to 1 and the existing ICI parallelism logic in max_utils is left to handle the rest.
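The CPU special case described above might look roughly like the following sketch. It is an assumption about the shape of the logic, not the function from this PR: `get_num_slices_sketch`, the `"hardware"` key, the explicit `devices` argument, and the use of a `slice_index` attribute on devices are all illustrative.

```python
from types import SimpleNamespace


def get_num_slices_sketch(raw_keys, devices):
    """Illustrative slice count: CPUs always report a single slice."""
    if raw_keys.get("hardware") == "cpu":  # assumed config key
        # CPUs have no hierarchical network, so treat all devices as one slice.
        return 1
    if raw_keys["compile_topology_num_slices"] > 0:
        # An explicit compile-time topology overrides device inspection.
        return raw_keys["compile_topology_num_slices"]
    # Otherwise count distinct slices; devices without the attribute count as slice 0.
    return 1 + max(getattr(d, "slice_index", 0) for d in devices)


# Stand-in device objects so the sketch runs without any accelerator.
fake_devices = [SimpleNamespace(slice_index=i // 4) for i in range(8)]
```

Returning 1 for CPUs means every device is treated as part of a single slice, so only the ICI parallelism settings apply to the mesh.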
Testing on multiprocess CPUs:
I tested standalone_checkpointer.py end to end on 2 node pools with 2 hosts each to verify this change.