Closed gramesh-amd closed 1 week ago
@ZhiyuLi-goog
The paxml ckpt is around 1.8 TB and the converted maxtext ckpt is around 449 GB. I'm guessing this is due to some compression?
No. We should expect exactly the same size.
gsutil du -h gs://path/to/dir/checkpoints/4000/items/
....
....
1.75 TiB gs://path/to/dir/checkpoints/4000/items/
In a distributed checkpoint save, each device only has access to the tensors stored locally on it. A quick check is to see whether 1.75 TiB divided by the number of devices is approximately equal to 449 GB, which would be the checkpoint size visible locally on each device.
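The quick check above can be sketched as plain arithmetic. This assumes 4 hosts (the 4-node conversion run reported in this issue) and that both the gsutil figure and the reported "449 GB" are in binary units; `per_host_gib` is a hypothetical helper, not part of MaxText or gsutil.

```python
TIB = 2**40
GIB = 2**30

def per_host_gib(total_tib: float, num_hosts: int) -> float:
    """Size each host would hold if the total were split evenly, in GiB."""
    return total_tib * TIB / num_hosts / GIB

# Assumption: 4 hosts, each writing to its own local disk.
share = per_host_gib(1.75, num_hosts=4)
print(f"{share:.0f} GiB per host")  # 448 GiB per host
```

Under those assumptions the result is close to the 449 GB seen locally, which is consistent with each host holding only its own part of the checkpoint rather than with compression.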
This is a good example of ckpt conversion and loading, except that the model is a test one:
https://github.com/google/maxtext/blob/main/end_to_end/tpu/test_gpt3.sh
./google-cloud-sdk/bin/gsutil -m cp -r gs://maxtext-gpt3/ckpt_test .
AccessDeniedException: 403 gowtham.ramesh@amd.com does not have storage.objects.list access to the Google Cloud Storage bucket. Permission 'storage.objects.list' denied on resource (or it may not exist).
CommandException: 1 file/object could not be transferred.
Could you give me permission to this bucket as well?
The checkpoint is also an artifact; let me check with Google's internal team tomorrow to double-check.
Just want to highlight: --base-output-directory=${OUTPUT_PATH} should be a GCS bucket instead of a local file path. This allows each device to save its partial checkpoints to the GCS bucket, which will then contain the complete, merged checkpoint.
With matching --run-name=${RUN_NAME} and --base-output-directory=${OUTPUT_PATH}, the ckpt should be found and loaded correctly.
We cannot load the checkpoint converted using either the main branch or the code that was checked into the MLPerf repo with your MLPerf training submission. The main branch of this submission has a problem due to a dimension mismatch - I think this is because the conversion script does not support the pipeline parallelism dimension.
When we try to use the code from the MLPerf training submission (in the maxtext_fork directory) to convert and load the checkpoint using either load_full_state_path or base_output_directory, we get a dictionary mismatch error like this (note that we have a local file path, but this path is unique for each machine as it is a local disk not on NFS):
1: I0911 08:32:54.909472 139922234083136 checkpointer.py:227] Restoring checkpoint from /mnt/m2m_nobackup/users/user/gpt3-conversion-forked/checkpoints/4000.
1: Traceback (most recent call last):
1:   File "/maxtext_fork/MaxText/train.py", line 558, in <module>
1:     app.run(main)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/absl/app.py", line 308, in run
1:     _run_main(main, args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
1:     sys.exit(main(argv))
1:   File "/maxtext_fork/MaxText/train.py", line 554, in main
1:     train_loop(config)
1:   File "/maxtext_fork/MaxText/train.py", line 365, in train_loop
1:     mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state, tx) = setup_train_loop(config)
1:   File "/maxtext_fork/MaxText/train.py", line 349, in setup_train_loop
1:     state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(model, data_iterator,
1:   File "/maxtext_fork/MaxText/max_utils.py", line 371, in setup_training_state
1:     return setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training)
1:   File "/maxtext_fork/MaxText/max_utils.py", line 396, in setup_initial_state
1:     restored, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
1:   File "/maxtext_fork/MaxText/checkpointing.py", line 103, in load_state_if_possible
1:     return checkpoint_manager.restore(latest_step,
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1187, in restore
1:     restored = self._checkpointer.restore(restore_directory, args=args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 229, in restore
1:     restored = self._handler.restore(directory, args=ckpt_args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 505, in restore
1:     restored[item_name] = handler.restore(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py", line 199, in restore
1:     return self._impl.restore(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 762, in restore
1:     return self._handler_impl.restore(directory, args=args)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 619, in restore
1:     restored_item = asyncio.run(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/asyncio/runners.py", line 44, in run
1:     return loop.run_until_complete(main)
1:   File "/pyenv/versions/3.10.14/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
1:     return future.result()
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 470, in _maybe_deserialize
1:     batch_requests = batched_serialization_requests(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 187, in batched_serialization_requests
1:     jax.tree_util.tree_map_with_path(
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py", line 1169, in tree_map_with_path
1:     all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
1:   File "/pyenv/versions/3.10.14/lib/python3.10/site-packages/jax/_src/tree_util.py", line 1169, in <listcomp>
1:     all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
1: ValueError: Dict key mismatch; expected keys: ['decoder_norm', 'layers', 
'position_embedder']; dict: {'decoder_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'layers_0': {'mlp': {'mlp_layer_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'a 1: utoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'wi': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'),)), global_shape=(49152,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSp 1: ec(('fsdp', 
'sequence'), ('fsdp_transpose', 'tensor', 'autoregressive'))), global_shape=(12288, 49152))}, 'wo': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'), ('fsdp', 'sequence'))), global_shape=(49152, 12288))}}, 'pre_self_attention_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1: 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'self_attention': {'out': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 
'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autore 1: gressive': 1), spec=PartitionSpec(('tensor', 'autoregressive'), (), ('fsdp', 'fsdp_transpose', 'sequence'))), global_shape=(96, 128, 12288))}, 'qkv_proj': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(None, ('tensor', 'autoregressive'), ())), global_shape=(3, 96, 128)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'), None, ('tensor', 'autoregressive'), ())), global_shape=(12288, 3, 96, 128))}}}, 'layers_1': {'mlp': {'mlp_layer_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=Named 1: Sharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'wi': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'),)), global_shape=(49152,)), 'kernel': ArrayRestoreArgs(restore_type=<class 
'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp' 1: : 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'sequence'), ('fsdp_transpose', 'tensor', 'autoregressive'))), global_shape=(12288, 49152))}, 'wo': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'), ('fsdp', 'sequence'))), global_shape=(49152, 12288))}}, 'pre_self_attention_norm': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=N 1: one, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'self_attention': {'out': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 
'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedShard 1: ing(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('tensor', 'autoregressive'), (), ('fsdp', 'fsdp_transpose', 'sequence'))), global_shape=(96, 128, 12288))}, 'qkv_proj': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(None, ('tensor', 'autoregressive'), ())), global_shape=(3, 96, 128)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'), None, ('tensor', 'autoregressive'), ())), global_shape=(12288, 3, 96, 128))}}}, 'layers_10': {'mlp': {'mlp_layer_norm': {'bias': ArrayRestoreArgs(resto 1: re_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'wi': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 
'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'),)), global_shape=(49152,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=d 1: type('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'sequence'), ('fsdp_transpose', 'tensor', 'autoregressive'))), global_shape=(12288, 49152))}, 'wo': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp_transpose', 'tensor', 'autoregressive'), ('fsdp', 'sequence'))), global_shape=(49152, 12288))}}, 'pre_self_attention_norm': {'bias': Arr 1: ayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,))}, 'self_attention': {'out': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, 
dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'),)), global_shape=(12288,)), 'kernel': ArrayRestoreArgs(restore_ty 1: pe=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('tensor', 'autoregressive'), (), ('fsdp', 'fsdp_transpose', 'sequence'))), global_shape=(96, 128, 12288))}, 'qkv_proj': {'bias': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(None, ('tensor', 'autoregressive'), ())), global_shape=(3, 96, 128)), 'kernel': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('data': 1, 'fsdp': 16, 'fsdp_transpose': 1, 'sequence': 1, 'tensor': 1, 'autoregressive': 1), spec=PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'), None, ('tensor', 'autoregressive'), ())), global_shape=(1 1: 2288, 3, 96, 128))}}}, (I have truncated the error message here - it lists the contents of the dictionary)
We are able to start training from random weights, save a checkpoint, and resume training in another job, so we believe that the problem is in converting the checkpoint from paxml to orbax.
Hi @gabeweisz
We cannot load the checkpoint converted using either the main branch or the code that was checked into the MLPerf repo with your MLPerf training submission.
If you are using the MLPerf training submission branch, you should run the training code from that same branch. There is a compatibility issue: the layer structure has changed between branches, which causes a key mismatch.
note that we have a local file path, but this path is unique for each machine as it is a local disk not on NFS
Thank you for letting me know. Is it possible to save to a globally accessible directory, either a GCS bucket or a mounted one?
The main branch of this submission has a problem due to a dimension mismatch - I think this is because the conversion script does not support the pipeline parallelism dimension.
An Orbax checkpoint is parallelism-agnostic as long as the checkpoint keys match; for example, you should be able to save it with FSDP and then run it with FSDP and TP 2D sharding. I can give pipeline parallelism another try too.
It would be great if you could share some repro scripts, including the branch you are using.
This is all running in Google's official submission branch that we obtained from https://github.com/mlcommons/training_results_v4.0/tree/main/Google/benchmarks/gpt3/implementations/maxtext/maxtext_fork
We did also try a single globally-visible output directory, and it did not make a difference to this functionality.
We had to make minor changes to the checkpoint generation script to turn on FSDP across nodes (since we are running 4 GPU nodes each with 8 GPUs) and to load the paxml checkpoint from a local directory.
I am attaching the modified version of the checkpoint creation script and the model config file (adapted from the base.yml that came with the branch) that we used. Sorry, GitHub made me change the extensions before uploading them.
We're not using any unusual parameters to the scripts - just passing the local paxml and base output directories to the conversion script, and calling train.py with gpt3_175b_gpu.yml and the location of the checkpoint (we tried load_full_state_path, load_parameters_path, and base_output_directory)
From looking at the dictionary error - it appears that the script is expecting a key called "layers", and instead we have keys for "layers_0", "layers_1", ... - so something is flattened when it should not be.
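The observation above can be checked mechanically by diffing the expected and found top-level keys. The sketch below is a hypothetical helper (`diagnose_layer_keys` is not an Orbax or MaxText function), and the `found` list is an abbreviated, illustrative stand-in for the keys in the truncated error message.

```python
import re

def diagnose_layer_keys(expected_keys, found_keys):
    """Report whether a key mismatch looks like a stacked-vs-per-layer layout."""
    expected, found = set(expected_keys), set(found_keys)
    missing = sorted(expected - found)
    extra = sorted(found - expected)
    # Keys such as layers_0, layers_1, ... indicate one subtree per layer.
    per_layer = [k for k in extra if re.fullmatch(r"layers_\d+", k)]
    looks_unscanned = "layers" in missing and bool(per_layer)
    return {"missing": missing, "extra": extra, "looks_unscanned": looks_unscanned}

expected = ["decoder_norm", "layers", "position_embedder"]
# Illustrative subset only; the real error lists many more layers_N keys.
found = ["decoder_norm", "layers_0", "layers_1", "layers_10", "position_embedder"]
report = diagnose_layer_keys(expected, found)
print(report["missing"], report["looks_unscanned"])  # ['layers'] True
```

A single missing `layers` key alongside extra `layers_N` keys suggests that one side uses a stacked layout and the other a per-layer layout, rather than a JAX version problem.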
Or maybe it is a jax versioning issue? We are using 0.4.30.
Thank you for the info.
Could you try running both the conversion and the training scripts with scan_layers: True? I saw that you changed it in your configurations.
This combines the weights from the different layers into a single big tensor. Instead of separate layers_0 ... layers_95 keys, one per layer, there is a single concatenated tensor under one key, layers.
scan_layers=True would also help a lot in decreasing compilation time.
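To illustrate the difference between the two layouts, here is a minimal pure-Python sketch that merges per-layer subtrees into one `layers` subtree. It is not MaxText's conversion code: `stack_layers` is a hypothetical helper, leaves are plain lists rather than arrays, and the layer axis is assumed to be the leading one, which may differ from the axis MaxText actually uses.

```python
def stack_layers(params, prefix="layers_"):
    """Merge layers_0 .. layers_{N-1} subtrees into one 'layers' subtree."""
    # Sort numerically so that layers_10 comes after layers_2, not after layers_1.
    layer_keys = sorted(
        (k for k in params if k.startswith(prefix) and k[len(prefix):].isdigit()),
        key=lambda k: int(k[len(prefix):]),
    )

    def stack(trees):
        first = trees[0]
        if isinstance(first, dict):
            # Recurse into sub-dicts; every layer is assumed to share one structure.
            return {name: stack([t[name] for t in trees]) for name in first}
        # Leaf: collect the per-layer values along a new leading "layer" axis.
        return list(trees)

    stacked = {k: v for k, v in params.items() if k not in layer_keys}
    stacked["layers"] = stack([params[k] for k in layer_keys])
    return stacked

unscanned = {
    "decoder_norm": {"scale": [1.0, 1.0]},
    "layers_0": {"mlp": {"wi": {"bias": [0.0, 0.1]}}},
    "layers_1": {"mlp": {"wi": {"bias": [1.0, 1.1]}}},
    "layers_2": {"mlp": {"wi": {"bias": [2.0, 2.1]}}},
}
scanned = stack_layers(unscanned)
print(sorted(scanned))  # ['decoder_norm', 'layers']
print(scanned["layers"]["mlp"]["wi"]["bias"])  # [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
```

The two trees hold the same numbers but have different keys and shapes, which is why a checkpoint written in one layout cannot be restored into a model expecting the other.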
That seems to be the key - we used scan_layers=true to create the checkpoint but were using scan_layers=false for training, and that does not seem to work. Thank you for helping us! We'll do a bit more testing and then close the ticket - it seems to work both with the branch that Google used for the MLPerf training submission and with the head of the MaxText main branch.
Happy to hear that. Any time!
Hello,
I managed to convert the gpt3 pax ckpt using the convert_gpt3_ckpt_from_paxml.py script with 4 nodes (8 GPUs in each node) and the following base_args:
The paxml ckpt is around 1.8 TB and the converted maxtext ckpt is around 449 GB. I'm guessing this is due to some compression?
When I try to reload this ckpt with the following config:
^ Note that I'm using load_full_state_path. If I just specify run_name and base_output_directory without load_full_state_path or load_parameters_path, maxtext fails to find the checkpoint:
When I use load_full_state_path: "4000/items", I get the following error: Dict key mismatch; expected keys: [...] ckpt-conversion-loading-error.txt
^ see the full error here